summaryrefslogtreecommitdiffstats
path: root/mlplib/init_test.py
blob: 7198e38497e628848c349d6f9126c80d136a25f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
# -*- coding: utf-8 -*-
#
# Copyright 2021 Michael Büsch <m@bues.ch>
#
# Licensed under the Apache License version 2.0
# or the MIT license, at your option.
# SPDX-License-Identifier: Apache-2.0 OR MIT
#
"""

from mlplib.init import *
import numpy as np

def test_init_weights():
    """Check the shape and value range of a freshly initialized weight matrix."""
    rows, cols = 100, 200
    lower, upper = -1, 1

    seed(100)
    weights = init_weights(rows, cols)

    # The matrix must have exactly the requested dimensions.
    assert weights.shape == (rows, cols)
    # Every weight must lie within the closed interval [-1, 1].
    assert np.all(weights >= lower)
    assert np.all(weights <= upper)

def test_init_biases():
    """Check the shape and fill value of bias vectors, default and explicit."""
    seed(101)

    # (size, extra keyword arguments, value every element must equal)
    cases = (
        (200, {}, 0),
        (201, {"initial": 0.5}, 0.5),
    )
    for size, extra_args, expected in cases:
        biases = init_biases(size, **extra_args)
        # Biases come back as a single row of the requested size.
        assert biases.shape == (1, size)
        assert np.all(biases == expected)

def test_seed():
    """Check that seeding makes weight initialization reproducible.

    The same seed must yield identical weights; a different seed must
    yield weights that differ in every element.
    """
    def weights_for(seed_value):
        # Reseed immediately before drawing so each draw depends only on the seed.
        seed(seed_value)
        return init_weights(10, 20)

    first, repeated, other = [weights_for(s) for s in (42, 42, 43)]

    assert np.all(first == repeated)
    assert np.all(first != other)

# vim: ts=4 sw=4 expandtab
bues.ch cgit interface