Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Beichen-Ma committed Mar 20, 2024
1 parent 6705dc3 commit cf5d5cf
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions src/size_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,15 @@ def hook(module, input, output):
bias_ops = output_features if module.bias is not None else 0
flops[module] = batch_size * (2 * input_features * output_features + bias_ops)

# if isinstance(module, nn.Conv2d):
# in_channels = input[0].shape[1]
# out_channels = output.shape[1]
# out_h = output.shape[2]
# out_w = output.shape[3]
# k_ops = 2 * module.kernel_size[0] * module.kernel_size[1] * (in_channels // module.groups)
# bias_ops = out_channels if module.bias is not None else 0
# flops[module] = batch_size * (k_ops * out_h + bias_ops) * (out_w + out_h * out_w)
if isinstance(module, nn.Conv2d):
in_channels = input[0].shape[1]
out_channels = output.shape[1]
out_h = output.shape[2]
out_w = output.shape[3]
kernel_elements = module.kernel_size[0] * module.kernel_size[1]
k_ops = kernel_elements * (in_channels // module.groups)
bias_ops = out_channels if module.bias is not None else 0
# Corrected FLOPs calculation for Conv2d
flops[module] = batch_size * out_h * out_w * out_channels * (k_ops * 2 + bias_ops)
flops[module] = 0
if module.bias is not None:
bias_flops = np.prod(output.shape[1:]) * batch_size
flops[module] += bias_flops

layer_shape = list(module.weight.size())
weight_flops = np.prod(output.shape[1:]) * 2 * layer_shape[1] * layer_shape[2] * layer_shape[3] * batch_size
flops[module] += weight_flops

if isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
num_elements = input[0].numel()
Expand Down

0 comments on commit cf5d5cf

Please sign in to comment.