fix silly mistakes that would make the network not learn
guidopetri committed Nov 1, 2020
1 parent ea5b263 commit fb282be
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions csgo_wp/model.py
@@ -116,6 +116,9 @@ def __init__(self,
         self.blocks = torch.nn.ModuleList()

         for input_size, output_size in pairwise(hidden_sizes):
+            # last unit: don't use activation
+            if output_size == self.output_size:
+                activation = 'Identity'
             block = ResidualBlock(input_size,
                                   activation,
                                   activation_params)
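
For context, a minimal sketch of what the added guard does. The pairwise helper and the lookup of activations by class name in torch.nn are assumptions based on the surrounding diff, not shown in it:

```python
import itertools

import torch


def pairwise(iterable):
    # (a, b), (b, c), (c, d), ... -- the classic itertools recipe;
    # assumed to match the repo's helper of the same name
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


# hypothetical sizes: two hidden layers feeding a 1-unit output
hidden_sizes = [128, 64, 1]
output_size = 1
activation = 'ReLU'

for input_size, out_size in pairwise(hidden_sizes):
    if out_size == output_size:
        # last unit: don't use activation, so the sigmoid applied in
        # forward() sees raw logits rather than ReLU-truncated ones
        activation = 'Identity'
    act = getattr(torch.nn, activation)()
    print(input_size, out_size, act.__class__.__name__)
# 128 64 ReLU
# 64 1 Identity
```

Swapping in Identity for the final pair is what lets the head learn: a ReLU there would clamp all negative logits to zero before the sigmoid, capping predictions at 0.5 from below.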
@@ -221,10 +224,6 @@ def __init__(self,
                             activation,
                             activation_params)

-        if self.bn_activated:
-            self.norm = torch.nn.BatchNorm1d(num_features=self.output_size)
-        else:
-            self.norm = torch.nn.Identity()
         # sigmoid
         self.sigmoid = torch.nn.Sigmoid()

@@ -238,7 +237,6 @@ def forward(self, x):
         x = x.reshape(-1, self.num_elements_output)

         x = self.linear(x)
-        x = self.norm(x)
         x = self.sigmoid(x)

         return x.squeeze(1)
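
The removed self.norm is the core of the fix in this class: a BatchNorm1d between the last linear layer and the sigmoid standardizes the logits over each batch, so whatever shift the network learns is undone and the mean prediction stays pinned near 0.5. A self-contained demo of the effect, with illustrative values rather than code from the repo:

```python
import torch

torch.manual_seed(0)

# hypothetical logits from a head that has learned a confident
# positive shift (mean 3, i.e. "this side usually wins")
logits = torch.randn(256, 1) * 2 + 3

bn = torch.nn.BatchNorm1d(num_features=1).train()

p_raw = torch.sigmoid(logits)      # what the fixed model outputs
p_bn = torch.sigmoid(bn(logits))   # what the buggy model output

print(p_raw.mean().item())  # well above 0.5: the learned shift survives
print(p_bn.mean().item())   # ~0.5: BatchNorm re-centered the logits
```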
@@ -271,17 +269,20 @@ def __init__(self,
         self.linear_blocks = torch.nn.ModuleList()

         for input_size, output_size in pairwise(hidden_sizes):
+            # last unit: don't use activation
+            if output_size == self.output_size:
+                activation = 'Identity'
             block = LinearBlock(input_size,
                                 output_size,
                                 activation,
                                 activation_params)
             self.linear_blocks.append(block)

-            if self.bn_activated:
+            if self.bn_activated and output_size != self.output_size:
                 norm = torch.nn.BatchNorm1d(num_features=output_size)
                 self.linear_blocks.append(norm)

-            if self.dropout_activated:
+            if self.dropout_activated and output_size != self.output_size:
                 dropout_block = torch.nn.Dropout()
                 self.linear_blocks.append(dropout_block)
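
Putting the corrected constructor logic together, a runnable sketch; LinearBlock's internals, the activation lookup, and the layer sizes are assumptions, and only the loop structure comes from this diff:

```python
import itertools

import torch


def pairwise(iterable):
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


class LinearBlock(torch.nn.Module):
    # assumed shape of the repo's LinearBlock: Linear + activation
    def __init__(self, input_size, output_size,
                 activation, activation_params=None):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, output_size)
        self.activation = getattr(torch.nn, activation)(
            **(activation_params or {}))

    def forward(self, x):
        return self.activation(self.linear(x))


hidden_sizes = [64, 32, 1]   # hypothetical
output_size = hidden_sizes[-1]
bn_activated = dropout_activated = True
activation = 'ReLU'

linear_blocks = torch.nn.ModuleList()
for input_size, out_size in pairwise(hidden_sizes):
    if out_size == output_size:
        activation = 'Identity'   # last unit: don't use activation
    linear_blocks.append(LinearBlock(input_size, out_size, activation))

    # BatchNorm/Dropout only after hidden layers; applying either to
    # the single output unit would distort the logits fed to sigmoid
    if bn_activated and out_size != output_size:
        linear_blocks.append(torch.nn.BatchNorm1d(num_features=out_size))
    if dropout_activated and out_size != output_size:
        linear_blocks.append(torch.nn.Dropout())

x = torch.randn(8, hidden_sizes[0])
for block in linear_blocks:
    x = block(x)
print(torch.sigmoid(x).squeeze(1))  # 8 probabilities in (0, 1)
```

Before this commit the guards on bn_activated and dropout_activated lacked the output_size check, so the 1-unit output layer also got batch-normalized and dropped out, which is exactly what kept the network from learning.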
