blob: 7198e38497e628848c349d6f9126c80d136a25f3 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
"""
# -*- coding: utf-8 -*-
#
# Copyright 2021 Michael Büsch <m@bues.ch>
#
# Licensed under the Apache License version 2.0
# or the MIT license, at your option.
# SPDX-License-Identifier: Apache-2.0 OR MIT
#
"""
from mlplib.init import *
import numpy as np
def test_init_weights():
    """Check that init_weights(100, 200) returns a 100x200 array bounded to [-1, 1]."""
    seed(100)
    weights = init_weights(100, 200)
    assert weights.shape == (100, 200)
    # Both bounds are checked in one combined mask; any element outside
    # the closed interval [-1, 1] (or NaN) makes the assertion fail.
    in_range = (weights >= -1) & (weights <= 1)
    assert np.all(in_range)
def test_init_biases():
    """Check init_biases' (1, n) shape and its fill value.

    Without ``initial`` every bias is expected to be 0; with
    ``initial=0.5`` every bias is expected to be 0.5.
    """
    seed(101)
    # (size, extra keyword arguments, expected fill value)
    cases = (
        (200, {}, 0),
        (201, {"initial": 0.5}, 0.5),
    )
    for size, kwargs, expected in cases:
        biases = init_biases(size, **kwargs)
        assert biases.shape == (1, size)
        assert np.all(biases == expected)
def test_seed():
    """Check that seed() makes init_weights reproducible.

    Re-seeding with the same value must reproduce the weights exactly,
    while a different seed must yield weights that differ element-wise.
    """
    def draw(seed_value):
        # Re-seed, then draw a fresh 10x20 weight matrix.
        seed(seed_value)
        return init_weights(10, 20)

    first = draw(42)
    repeat = draw(42)
    other = draw(43)
    assert np.all(first == repeat)
    assert np.all(first != other)
# vim: ts=4 sw=4 expandtab
|