"""
# -*- coding: utf-8 -*-
#
# Copyright 2021 Michael Büsch
#
# Licensed under the Apache License version 2.0
# or the MIT license, at your option.
# SPDX-License-Identifier: Apache-2.0 OR MIT
#
"""

__all__ = [
    "Loss",
    "MAE",
    "MSE",
]

from abc import ABC, abstractmethod

import numpy as np


def _align(yh, y):
    """Return *y* reshaped to the shape of *yh*.

    Both arrays must contain the same number of elements. Reshaping prevents
    accidental NumPy broadcasting: without it, an (n,) prediction compared
    against an (n, 1) target would produce an (n, n) difference matrix
    instead of n element-wise differences. When the shapes already match,
    this is a no-op.
    """
    assert yh.size == y.size
    return np.reshape(y, np.shape(yh))


class Loss(ABC):
    """Abstract base class for loss functions.

    Subclasses implement the forward loss `fn` and its derivative `fn_d`.
    """

    @abstractmethod
    def fn(self, yh, y):
        """Forward loss function.
        yh: predicted value
        y: expected value
        """

    @abstractmethod
    def fn_d(self, yh, y):
        """Loss function derivative.
        yh: predicted value
        y: expected value
        """


class MAE(Loss):
    """MAE
    Mean Absolute Error (L1) loss.
    """

    def fn(self, yh, y):
        """Return mean(|y - yh|).

        yh: predicted values
        y: expected values (same number of elements as yh)
        Returns 0.0 when both inputs are empty.
        """
        y = _align(yh, y)
        if y.size:
            return np.absolute(y - yh).mean()
        # mean() of an empty array warns and yields NaN; report zero loss.
        return 0.0

    def fn_d(self, yh, y):
        """Return the element-wise subgradient sign(yh - y), shaped like yh.

        The result is +1 where yh > y, -1 where yh < y and 0 where they are
        equal, so an exact prediction receives no gradient. The 1/n factor
        of the mean in `fn` is not applied.
        """
        y = _align(yh, y)
        # `* 1.0` keeps the gradient floating-point even for integer targets;
        # float targets keep their own dtype (e.g. float32 stays float32).
        return (yh > y).astype(y.dtype) * 1.0 - (yh < y)


class MSE(Loss):
    """MSE
    Mean Squared Error (L2) loss.
    """

    def fn(self, yh, y):
        """Return mean((y - yh) ** 2).

        yh: predicted values
        y: expected values (same number of elements as yh)
        Returns 0.0 when both inputs are empty.
        """
        y = _align(yh, y)
        if y.size:
            return np.square(y - yh).mean()
        # mean() of an empty array warns and yields NaN; report zero loss.
        return 0.0

    def fn_d(self, yh, y):
        """Return the element-wise derivative yh - y, shaped like yh.

        This is the derivative of 0.5 * (yh - y) ** 2; the constant factor 2
        and the 1/n factor of the mean in `fn` are not applied.
        """
        y = _align(yh, y)
        return yh - y

# vim: ts=4 sw=4 expandtab