""" # -*- coding: utf-8 -*- # # Copyright 2021 Michael Büsch # # Licensed under the Apache License version 2.0 # or the MIT license, at your option. # SPDX-License-Identifier: Apache-2.0 OR MIT # """ __all__ = [ "Parameter", "Parameters", ] from collections import namedtuple from dataclasses import dataclass from mlplib.activation import Activation from mlplib.util import GenericIter from typing import List, Tuple, Callable, Optional import numpy as np Parameter = namedtuple("Parameter", ["w", "b", "actv"]) @dataclass class Parameters(object): """weights: List of np.array with shape (prev_n, n). biases: List of np.array with shape (1, n). actvns: List of activation functions. """ weights: List[np.ndarray] biases: List[np.ndarray] actvns: List[Activation] @property def layout(self) -> Tuple[int, ...]: return tuple(w.shape[1] for w in self.weights) @property def nr_inputs(self) -> int: assert len(self.weights) >= 1 return self.weights[0].shape[0] @property def nr_outputs(self) -> int: assert len(self.weights) >= 1 return self.weights[-1].shape[1] def __iter__(self): return ParametersIter(self, len(self.weights)) def __reversed__(self): return ParametersIter(self, len(self.weights), True, len(self.weights) - 1) def __str__(self): ret = [] for i, (w, b, a) in enumerate(zip(self.weights, self.biases, self.actvns)): wstr = "\n ".join(str(w).splitlines()) bstr = "\n ".join(str(b).splitlines()) ret.append(f"L{i} w: {wstr}") ret.append(f" b: {bstr}") ret.append(f" {a.__class__.__name__}") return "\n".join(ret) @dataclass class ParametersIter(GenericIter): def __next__(self): obj, pos = self._next() return Parameter( w=obj.weights[pos], b=obj.biases[pos], actv=obj.actvns[pos], ) # vim: ts=4 sw=4 expandtab