""" # -*- coding: utf-8 -*- # # Copyright 2021 Michael Büsch # # Licensed under the Apache License version 2.0 # or the MIT license, at your option. # SPDX-License-Identifier: Apache-2.0 OR MIT # """ __all__ = [ "BackpropGrad", "BackpropGrads", "backward_prop", ] from mlplib.forward import forward_prop from mlplib.loss import Loss from mlplib.parameters import Parameters from mlplib.util import GenericIter from collections import deque, namedtuple from dataclasses import dataclass, field from typing import Callable, Optional, Tuple import numpy as np BackpropGrad = namedtuple("BackpropGrad", ["dw", "db"]) @dataclass class BackpropGrads(object): """ Calculated backpropagation gradients. """ dw: deque[np.ndarray] = field(default_factory=deque) db: deque[np.ndarray] = field(default_factory=deque) def __iter__(self): return BackpropGradsIter(self, len(self.dw)) def __reversed__(self): return BackpropGradsIter(self, len(self.dw), True, len(self.dw) - 1) @dataclass class BackpropGradsIter(GenericIter): def __next__(self): obj, pos = self._next() return BackpropGrad( dw=obj.dw[pos], db=obj.db[pos], ) def backward_prop(x: np.ndarray, y: np.ndarray, params: Parameters, loss: Loss)\ -> Tuple[BackpropGrads, np.ndarray]: assert len(params.weights) >= 1 assert len(params.weights) == len(params.biases) assert len(params.weights) == len(params.actvns) # Number of samples. m = x.shape[0] # Run the network in forward direction. yh, netstate = forward_prop(x, params, store_netstate=True) assert isinstance(netstate, list) assert len(params.weights) == len(netstate) assert x.shape[0] == yh.shape[0] # Calculate the net output loss derivative. da = loss.fn_d(yh, y) assert da.shape == (m, y.shape[1]) grads = BackpropGrads() last_layer = len(params.weights) - 1 for l_rev, ((w, _b, actv, *_), state) in\ enumerate(zip(reversed(params), reversed(netstate))): assert w.ndim == 2 and w.shape[0] == state.x.shape[1] assert da.ndim == 2 and da.shape == (m, w.shape[1]) # Calculate the neuron backward propagation. dw, db, da = actv.backward_prop(w, da, state.x, state.z, m, l_rev != last_layer) # Store the calculated gradients. grads.dw.appendleft(dw) grads.db.appendleft(db) assert len(grads.dw) == len(params.weights) assert len(grads.db) == len(params.weights) return grads, yh # vim: ts=4 sw=4 expandtab