biolayer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from time import time


# TODO: use form [batch_size, dim_size] instead of the transposed.
def _compute_step_linear(inputs, synapses, p, delta, k, eps):
    # inputs: [batch_size, input_size], synapses: [n_hidden, input_size].
    # Returns the normalised weight update, scaled by the learning rate eps.
    with torch.no_grad():
        # inputs /= inputs.norm(1) + 1e-10
        batch_size, input_size = inputs.shape[:2]
        idx_batch = torch.arange(batch_size, device=inputs.device)
        sig = synapses.sign()
        # Pre-activation currents under the Lebesgue-p inner product.
        tot_input = torch.matmul(sig * synapses.abs().pow(p - 1), inputs.t())

        # Rank hidden units per sample: the strongest is pushed towards the
        # input (+1), the next k-1 are pushed away (-delta).
        values = tot_input.clone()
        yl = torch.zeros_like(values)
        for i in range(k):
            max_activation_idx = torch.argmax(values, 0)
            values[max_activation_idx, idx_batch] = -1e10
            yl[max_activation_idx, idx_batch] = 1 if i == 0 else -delta
        # Ignore units whose current is non-positive.
        yl[tot_input <= 0] = 0

        xx = (yl * tot_input).sum(1)
        ds = torch.matmul(yl, inputs) - xx.view(xx.shape[0], 1).repeat(1, input_size) * synapses

        # Normalise by the largest absolute update so eps bounds the step size.
        nc = ds.abs().max() + 1e-30
        return eps * (ds / nc)
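
# Sketch of the rule the update above appears to implement (a Krotov/Hopfield-
# style competition among hidden units; this summary is an interpretation, not
# a statement from the original author). Per sample x, with currents
#     I_mu = sum_i sign(W_mu_i) * |W_mu_i|^(p-1) * x_i,
# the post-synaptic factor g(mu) is +1 for the most active unit, -delta for the
# next k-1 units, and 0 otherwise (and is zeroed wherever I_mu <= 0), giving
#     dW_mu_i  proportional to  g(mu) * (x_i - I_mu * W_mu_i),
# summed over the batch and then normalised by max |dW| so eps sets the step.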


class _BioBase:
    def __init__(self, p=2, delta=0.4, k=2):
        assert p >= 2, 'Lebesgue norm must be greater than or equal to 2'
        assert k >= 2, 'ranking parameter must be greater than or equal to 2'
        n_hiddens = getattr(self, 'out_features', None) or self.out_channels
        assert k <= n_hiddens, "ranking parameter can't exceed the number of hidden units"
        self._p = p
        self._delta = delta
        self._k = k
        self.weight.data.normal_()
        if self.bias is not None:
            self.bias.data.uniform_(-0.25, 0.25)

    def train_step(self, inputs, eps):
        raise NotImplementedError

    def train(self, *args, **kwargs):
        # Note: this shadows nn.Module.train(mode); here it runs the local
        # plasticity updates and returns a generator of weight snapshots.
        ds = args[0]
        assert isinstance(ds, (torch.Tensor, DataLoader)), 'Only a Tensor or a DataLoader is allowed as dataset'
        if isinstance(ds, torch.Tensor):
            return self._train_from_tensor(*args, **kwargs)
        else:
            return self._train_from_dataloader(*args, **kwargs)

    def _train_from_tensor(self, train_data, epochs=None, batch_size=100, epsilon=2e-2):
        dataset = TensorDataset(train_data)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        return self._train_from_dataloader(loader, epochs, epsilon=epsilon)

    def _train_from_dataloader(self, loader, epochs=None, epsilon=2e-2):
        t0 = time()
        max_epochs = 300 if epochs is None else epochs
        nep = -1
        while epochs is None or nep < epochs:
            nep += 1
            ep = min(nep, max_epochs - 1)
            # Linearly annealed learning rate.
            eps = epsilon * (1 - ep / max_epochs)
            for i, batch_samples in enumerate(loader):
                self.train_step(batch_samples[0].to(self.weight.device), eps)
                # With no epoch count given, stream a weight snapshot roughly
                # every 0.25 s so the caller can monitor training.
                if epochs is None and time() - t0 >= 0.25:
                    t0 = time()
                    yield self.weight.data


class BioLinear(_BioBase, nn.Linear):
    def __init__(self, in_features, out_features, p=2, delta=0.4, k=2, **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        _BioBase.__init__(self, p, delta, k)

    def train_step(self, inputs, eps):
        if self.bias is None:
            synapses = self.weight.data
            wdelta = _compute_step_linear(inputs, synapses, self._p, self._delta, self._k, eps)
            synapses += wdelta
        else:
            # Treat the bias as one extra synapse driven by a constant input of 1.
            bias = self.bias.data.unsqueeze(1)
            synapses = torch.concat([self.weight.data, bias], dim=1)
            inputs = torch.concat([inputs, torch.ones_like(inputs[:, :1])], dim=1)
            wdelta = _compute_step_linear(inputs, synapses, self._p, self._delta, self._k, eps)
            wdelta, bdelta = wdelta.split([wdelta.size(1) - 1, 1], dim=1)
            self.weight.data += wdelta
            self.bias.data += bdelta.squeeze(1)
        # synapses *= (1 - 1e-5)  # TODO: add weight decay
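
# Illustrative usage of BioLinear (a sketch, not taken from the original file;
# the shapes and hyperparameters below are arbitrary assumptions):
#
#   layer = BioLinear(784, 100, p=3, delta=0.4, k=7, bias=False)
#   data = torch.rand(10000, 784)
#   # With epochs=None the generator yields weight snapshots roughly every
#   # 0.25 s until the caller stops iterating:
#   for i, weights in enumerate(layer.train(data, batch_size=100)):
#       if i > 20:
#           break
#   # With an explicit epoch count nothing is yielded, so exhaust the
#   # generator to run the whole training loop:
#   list(layer.train(data, epochs=10, batch_size=100))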


class BioConv2d(_BioBase, nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True, **kwargs):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding, dilation=dilation, bias=bias)
        _BioBase.__init__(self, **kwargs)
        self.weight.data.abs_()
        # nn.Conv2d normalises kernel_size to a tuple.
        self._in_features = in_channels * self.kernel_size[0] * self.kernel_size[1]
        self._output_shape = None

    def _extract_blocks(self, inputs):
        return F.unfold(
            inputs, self.kernel_size,
            dilation=self.dilation, padding=self.padding, stride=self.stride,
        )

    def train_step(self, inputs, eps):
        assert len(inputs.shape) == 4, 'Inputs must be images with shape [B,C,H,W], got {}'.format(inputs.shape)
        synapses = self.weight.data
        with torch.no_grad():
            blocks = self._extract_blocks(inputs)
            # TODO: keep random patch selection? It seems to generate more
            # diverse patches, since there is high redundancy within an image.
            random_patches = True
            if random_patches:
                perm = torch.randperm(blocks.size(2))
                idx = perm[:5]
                blocks = blocks[:, :, idx]
            # Flatten the patches to [n_patches, C*kH*kW] and the kernels to
            # [out_channels, C*kH*kW] so the linear update rule applies.
            blocks = blocks.transpose(2, 1).contiguous().view(-1, self._in_features)
            syn_flat = synapses.view(synapses.shape[0], -1)
            if self.bias is None:
                synapses = syn_flat
            else:
                # Append the bias as an extra synapse with constant input 1.
                bias = self.bias.data.unsqueeze(1)
                biased_syn = torch.concat([syn_flat, bias], dim=1)
                blocks = torch.concat([blocks, torch.ones_like(blocks[:, :1])], dim=1)
                synapses = biased_syn
        wdelta = _compute_step_linear(blocks, synapses, self._p, self._delta, self._k, eps)
        with torch.no_grad():
            if self.bias is None:
                syn_flat += wdelta
            else:
                wdelta, bdelta = wdelta.split([wdelta.size(1) - 1, 1], dim=1)
                syn_flat += wdelta  # + 1e-3 * torch.randn_like(wdelta)
                self.bias.data += bdelta.squeeze(1)


# Alternative convolutional step that ranks activations over every spatial
# position (not wired into BioConv2d above, which reuses the linear rule on
# unfolded patches instead).
def _compute_step_conv(inputs, synapses, stride, padding, dilation, p, delta, k, eps):
    with torch.no_grad():
        num_units = synapses.shape[0]
        N = synapses.shape[1] * synapses.shape[2] * synapses.shape[3]  # inputs per neuron
        sig = synapses.sign()
        tot_input = F.conv2d(inputs, sig * synapses.abs().pow(p - 1),
                             stride=stride, padding=padding, dilation=dilation)
        # tot_input [batch, hid, H2, W2] --> [hid, batch*H2*W2]
        tot_input_flat = tot_input.transpose(0, 1).contiguous().view(num_units, -1)  # TODO: try reshape (view)
        Num = tot_input_flat.shape[1]  # batch*H2*W2
        idx_batch = torch.arange(Num, device=tot_input.device)

        values = tot_input_flat.clone()
        y1 = torch.argmax(values, 0)
        y = y1
        for i in range(k - 1):
            values[y, idx_batch] = -1e10
            y = torch.argmax(values, 0)
        y2 = y

        # One entry per unit and per position where it was evaluated.
        yl = torch.zeros(num_units, Num, device=tot_input.device)
        yl[y1, idx_batch] = 1.0      # +1 for the most active unit
        yl[y2, idx_batch] = -delta   # -delta for the k-th ranked unit
        xx = (yl * tot_input_flat).sum(1)  # [hid]

        # [hid, Num] x [Num, D] --> [hid, D] (accumulates the reinforcement per synapse)
        # Num ~ batch*H2*W2
        # D ~ C*k1*k2
        kernel_size = (synapses.shape[2], synapses.shape[3])
        blocks = F.unfold(inputs, kernel_size,
                          stride=stride, padding=padding, dilation=dilation)
        blocks = blocks.transpose(2, 1).contiguous().view(-1, N)
        flat_synapses = synapses.view(num_units, N)
        ds = torch.matmul(yl, blocks) - xx.view(num_units, 1).repeat(1, N) * flat_synapses

        nc = ds.abs().max() + 1e-30
        return (eps * (ds / nc)).view(num_units, -1, kernel_size[0], kernel_size[1])
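

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch added here, not part of the
    # original module; shapes and hyperparameters are arbitrary assumptions).
    torch.manual_seed(0)

    # Train a BioLinear layer for a couple of epochs on random data.
    linear = BioLinear(64, 16, p=2, delta=0.4, k=2)
    flat_data = torch.rand(512, 64)
    list(linear.train(flat_data, epochs=2, batch_size=64))
    print('BioLinear weights:', tuple(linear.weight.shape))

    # Train a BioConv2d layer on random single-channel images.
    conv = BioConv2d(1, 8, kernel_size=5, p=2, delta=0.4, k=2)
    images = torch.rand(256, 1, 28, 28)
    list(conv.train(images, epochs=2, batch_size=32))
    print('BioConv2d kernels:', tuple(conv.weight.shape))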