diff --git a/tinynn/graph/quantization/algorithm/cross_layer_equalization.py b/tinynn/graph/quantization/algorithm/cross_layer_equalization.py index 1afd864f..85001c9e 100644 --- a/tinynn/graph/quantization/algorithm/cross_layer_equalization.py +++ b/tinynn/graph/quantization/algorithm/cross_layer_equalization.py @@ -354,5 +354,5 @@ def model_rewrite(model, dummy_input, work_dir='out') -> nn.Module: def clear_model_fused_bn(model: nn.Module): """remove the attached bn from fused conv""" for mod in model.modules(): - if isinstance(mod, nn.Conv2d) and hasattr(mod, 'fused_bn_'): + if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)) and hasattr(mod, 'fused_bn_'): delattr(mod, 'fused_bn_')