""" # -*- coding: utf-8 -*- # # Copyright 2021 Michael Büsch # # Licensed under the Apache License version 2.0 # or the MIT license, at your option. # SPDX-License-Identifier: Apache-2.0 OR MIT # """ from mlplib.init import * import numpy as np def test_init_weights(): seed(100) a = init_weights(100, 200) assert a.shape == (100, 200) assert np.all(a >= -1) assert np.all(a <= 1) def test_init_biases(): seed(101) a = init_biases(200) assert a.shape == (1, 200) assert np.all(a == 0) a = init_biases(201, initial=0.5) assert a.shape == (1, 201) assert np.all(a == 0.5) def test_seed(): seed(42) a = init_weights(10, 20) seed(42) b = init_weights(10, 20) seed(43) c = init_weights(10, 20) assert np.all(a == b) assert np.all(a != c) # vim: ts=4 sw=4 expandtab