diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index 9126f1e..f585219 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -1386,11 +1386,17 @@ def parse(self, node, attrs, args, graph_converter):
         assert dilation_h == dilation_w == 1, "Only dilation == 1 is supported"
 
-        padding = tfl_schema.Padding.VALID
+        add_pad_op = not (
+            stride_h == stride_w == 1 and pad_h == kernel_h // 2 and pad_w == kernel_w // 2 and not ceil_mode
+        )
+        padding = tfl_schema.Padding.SAME
+        if add_pad_op:
+            padding = tfl_schema.Padding.VALID
 
         maxpool_op = tfl.MaxPool2dOperator(inputs, outputs, padding, stride_w, stride_h, kernel_w, kernel_h)
 
         ops = self.wrap_ops_with_nhwc_nchw_transposes([maxpool_op])
-        self.handle_padding(pad_h, pad_w, 1, ops, ceil_mode)
+        if add_pad_op:
+            self.handle_padding(pad_h, pad_w, 1, ops, ceil_mode)
 
         for op in ops:
             graph_converter.add_operator(op)