-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtests.py
159 lines (121 loc) · 4.57 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import pytest
import torch
from torch import nn
from liquidnet.main import LiquidNet, MappingType, ODESolver
from liquidnet.vision import VisionLiquidNet
def test_vision_liquid_net_initialization():
    """Building a VisionLiquidNet yields an nn.Module whose core is a LiquidNet."""
    net = VisionLiquidNet(64, 10)  # positional (num_units, num_classes)
    assert isinstance(net, nn.Module)
    assert isinstance(net.liquid_net, LiquidNet)
def test_vision_liquid_net_forward_pass():
    """A forward pass returns one row per input image and one column per class."""
    n_units, n_classes = 64, 10
    net = VisionLiquidNet(n_units, n_classes)
    n_images = 8
    # (batch, channels, height, width), matching the original test's variable names
    images = torch.randn(n_images, 3, 32, 32)
    result = net(images)
    assert result.shape == (n_images, n_classes)
def test_hidden_state_initialization():
    """After one forward pass the model exposes a non-None ``hidden_state``."""
    net = VisionLiquidNet(64, 10)
    # (batch, channels, height, width); the forward output itself is not inspected
    net(torch.randn(8, 3, 32, 32))
    assert net.hidden_state is not None
@pytest.fixture
def vision_liquid_net():
    """Provide a fresh VisionLiquidNet with 64 liquid units and 10 classes per test."""
    config = {"num_units": 64, "num_classes": 10}
    return VisionLiquidNet(**config)
# NOTE: the ``liquid_net`` fixture is defined exactly once, further down next to
# the NUM_UNITS / BATCH_SIZE constants. Defining it here as well would be dead
# code: Python rebinds the module-level name, and pytest only registers the
# final binding, so every test would silently receive the later fixture anyway.
def test_vision_liquid_net_forward(vision_liquid_net):
    """The fixture model maps a batch of four 3x32x32 inputs to a (4, 10) output."""
    image_shape = (3, 32, 32)  # (channels, height, width)
    batch = torch.randn(4, *image_shape)
    result = vision_liquid_net(batch)
    assert result.shape == (4, 10)  # (batch, num_classes of the fixture)
def test_vision_liquid_net_hidden_state(vision_liquid_net):
    """``hidden_state`` starts as None and is populated by the first forward pass."""
    net = vision_liquid_net
    batch = torch.randn(4, 3, 32, 32)  # (batch, channels, height, width)
    assert net.hidden_state is None
    net(batch)  # output ignored; only the side effect on hidden_state matters
    assert net.hidden_state is not None
def test_liquid_net_forward(liquid_net):
    """Outputs and final state both have shape (batch, state_size)."""
    width = liquid_net.state_size
    x = torch.randn(4, 32)
    h0 = torch.zeros(4, width)
    out, h_final = liquid_net(x, h0)
    expected = (4, width)
    assert out.shape == expected
    assert h_final.shape == expected
def test_liquid_net_parameter_constraints(liquid_net):
    """Every item yielded by ``get_param_constrain_op()`` is element-wise >= 0."""
    for tensor in liquid_net.get_param_constrain_op():
        assert torch.all(tensor >= 0)
# Shared dimensions for the LiquidNet tests below.
NUM_UNITS: int = 64  # state size passed to LiquidNet
BATCH_SIZE: int = 4  # rows in sample_inputs / initial_state
INPUT_SIZE: int = 32  # feature width of sample_inputs
NUM_ITERATIONS: int = 100  # not referenced by any test in this module
@pytest.fixture
def liquid_net():
    """Provide a fresh LiquidNet with ``NUM_UNITS`` units for each test."""
    net = LiquidNet(NUM_UNITS)
    return net
@pytest.fixture
def sample_inputs():
    """Random standard-normal inputs of shape (BATCH_SIZE, INPUT_SIZE)."""
    shape = (BATCH_SIZE, INPUT_SIZE)
    return torch.randn(*shape)
@pytest.fixture
def initial_state():
    """All-zero starting state of shape (BATCH_SIZE, NUM_UNITS)."""
    shape = (BATCH_SIZE, NUM_UNITS)
    return torch.zeros(*shape)
def test_liquid_net_initialization(liquid_net):
    """Both the state size and the output size equal the configured unit count."""
    for reported in (liquid_net.state_size, liquid_net.output_size):
        assert reported == NUM_UNITS
def test_forward_pass(liquid_net, sample_inputs, initial_state):
    """One step returns outputs and a final state, each of shape (BATCH_SIZE, NUM_UNITS)."""
    expected_shape = (BATCH_SIZE, NUM_UNITS)
    out, state = liquid_net(sample_inputs, initial_state)
    assert out.shape == expected_shape
    assert state.shape == expected_shape
def test_variable_constraints(liquid_net):
    """No item produced by ``get_param_constrain_op()`` contains a negative entry."""
    assert all(
        bool(torch.all(item >= 0)) for item in liquid_net.get_param_constrain_op()
    )
def test_export_weights(liquid_net, tmp_path):
    """``export_weights`` writes one CSV file per exported parameter.

    The target directory lives under pytest's per-test ``tmp_path`` rather than
    the current working directory. That keeps the repo clean, and it ensures a
    ``test_weights`` folder left over from an earlier run cannot make a broken
    export look successful.
    """
    # A not-yet-existing subdirectory, like the original relative "test_weights";
    # export_weights is expected to create it.
    dirname = str(tmp_path / "test_weights")
    liquid_net.export_weights(dirname)

    expected_files = [
        "w.csv",
        "erev.csv",
        "mu.csv",
        "sigma.csv",
        "sensory_w.csv",
        "sensory_erev.csv",
        "sensory_mu.csv",
        "sensory_sigma.csv",
        "vleak.csv",
        "gleak.csv",
        "cm.csv",
    ]
    # Collect every missing file so one failure reports all of them at once.
    missing = [
        name
        for name in expected_files
        if not os.path.exists(os.path.join(dirname, name))
    ]
    assert not missing, f"export_weights did not write: {missing}"
@pytest.mark.parametrize("solver", [ODESolver.SemiImplicit, ODESolver.Explicit])
@pytest.mark.parametrize(
    "mapping_type", [MappingType.Identity, MappingType.Linear, MappingType.Affine]
)
def test_solver_and_mapping_types(
    liquid_net, sample_inputs, initial_state, solver, mapping_type
):
    """Every solver / input-mapping combination runs and keeps the expected shapes.

    Previously this only checked that the forward pass did not raise. It now
    also asserts the same output/state shapes that ``test_forward_pass`` checks
    for the default configuration.
    """
    # NOTE(review): this pokes private attributes to switch configuration after
    # construction; assumes LiquidNet reads them on each forward call — confirm.
    liquid_net._solver = solver
    liquid_net._input_mapping = mapping_type
    outputs, final_state = liquid_net(sample_inputs, initial_state)
    assert outputs.shape == (BATCH_SIZE, NUM_UNITS)
    assert final_state.shape == (BATCH_SIZE, NUM_UNITS)
if __name__ == "__main__":
    # Run only this file (a bare pytest.main() collects from the current working
    # directory instead). Forward pytest's exit status so failures give a
    # non-zero process exit code.
    raise SystemExit(pytest.main([__file__]))