example_conv.py
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10

from biolayer import BioConv2d
from visualization import LayerVisualizer

CIFAR_DIR = '~/DataSets/'  # the CIFAR-10 data is expected to already be present here

batch_size = 50
learning_rate = 2e-3
kernel_size = 5
num_neurons = 10 * 10  # number of convolutional filters to learn
"""
Hebbian learning works best with positive inputs, which is also optimal when ReLU
activations are employed in subsequent layers. This is why inputs are normalized to the
range [0, 1] and not whitened.
"""
cifar = CIFAR10(CIFAR_DIR)
with torch.no_grad():
    cifar_data = torch.tensor(cifar.data, dtype=torch.float) / 255.
    cifar_data = cifar_data.transpose(1, 3).transpose(2, 3)  # [B,H,W,C] --> [B,C,H,W]
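
# One Hebbian convolution layer: 3 input channels, 100 filters of size 5x5,
# stride 3. The slice below keeps the first 3 weight channels and moves the
# channel axis last ([N,C,kH,kW] -> [N,kH,kW,C]) so the visualizer can draw
# each filter as an RGB patch.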
bio_conv = BioConv2d(3, num_neurons, kernel_size, stride=3).cuda()
weights = bio_conv.weight[:, :3, :, :].permute((0, 2, 3, 1))
vis = LayerVisualizer(weights)
avgs = []
stds = []
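
# BioConv2d.train is assumed to be a generator that performs one Hebbian
# update per batch and yields the current weight tensor, letting the loop
# refresh the live view and record weight statistics; Ctrl-C stops early.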
try:
    for weight in bio_conv.train(cifar_data, batch_size=batch_size, epsilon=learning_rate):
        vis.update()
        avgs.append(weight.mean().item())
        stds.append(weight.std().item())
except KeyboardInterrupt:
    pass
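
# Plot how the mean and standard deviation of the weights evolved during training.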
plt.plot(avgs, label="avg")
plt.plot(stds, label="std")
plt.legend()
plt.show()