Skip to content

Commit

Permalink
Support loading state dicts from post-Pytorch 1.0 standard BN models
Browse files Browse the repository at this point in the history
  • Loading branch information
ducksoup committed Aug 14, 2019
1 parent 5c3a083 commit 5d0dd8e
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions inplace_abn/abn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,21 @@ def forward(self, x):
else:
raise RuntimeError("Unknown activation function {}".format(self.activation))

def __repr__(self):
rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
' affine={affine}, activation={activation}'
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
# Post-Pytorch 1.0 models using standard BatchNorm have a "num_batches_tracked" parameter that we need to ignore
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]

super(ABN, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
error_msgs, unexpected_keys)

def extra_repr(self):
rep = '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, activation={activation}'
if self.activation in ["leaky_relu", "elu"]:
rep += '[{activation_param}])'
else:
rep += ')'
return rep.format(name=self.__class__.__name__, **self.__dict__)
rep += '[{activation_param}]'
return rep.format(**self.__dict__)


class InPlaceABN(ABN):
Expand Down

0 comments on commit 5d0dd8e

Please sign in to comment.