Does torch2trt support conversion on torch.nn.functional.avg_pool2d? #457

Open
oliviawindsir opened this issue Nov 27, 2020 · 4 comments

@oliviawindsir

oliviawindsir commented Nov 27, 2020

Hi guys, I need help understanding an unsupported-argument-types error raised on an IPoolingLayer.

I was trying to convert a model consisting of resnet50 + Deepmar by this repo and save the serialized engine:

However, torch2trt throws the error below when converting torch.nn.functional.avg_pool2d. I checked the codebase, and it seems torch2trt does support conversion for this layer.

I would like some guidance on pinpointing the problem with converting torch.nn.functional.avg_pool2d. Could someone point me to where to start debugging? Thanks a bunch!

Error: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: tensorrt.tensorrt.IPoolingLayer, arg1: tensorrt.tensorrt.Dims) -> None

Invoked with: <tensorrt.tensorrt.IPoolingLayer object at 0x7fbdfe28b768>, (None, None)

Extra info:

  • torch v1.5.0
  • torchvision 0.5.0
  • TRT v7.0.0.11

Snippet of code that calls the conversion:

import torch
import torch2trt

# model_pytorch: the resnet50 + Deepmar model described above
img = torch.zeros([1, 3, 512, 512]).cuda()
try:
    model_trt = torch2trt.torch2trt(model_pytorch, [img])
    engine_path = "deepmar_only_test.engine"
    # Save the serialized TensorRT engine to disk
    with open(engine_path, "wb") as f:
        f.write(model_trt.engine.serialize())
    model_deepmar = model_trt
    # model_deepmar.eval()
    # model_deepmar.cuda()
except Exception as e:
    print("Error: %s" % e)
@oliviawindsir oliviawindsir changed the title Error of unsupported arguments type invoked by IPoolingLayer Does torch2trt support conversion on torch.nn.functional.avg_pool2d? Nov 30, 2020
@oliviawindsir
Author

Hmm, okay, I managed to fix this issue. It's caused by the stride argument, which defaults to None. According to PyTorch's documentation, the stride defaults to kernel_size if not specified.

With that, I made the following changes in torch2trt/torch2trt/converters/avg_pool.py:

@tensorrt_converter('torch.nn.functional.avg_pool2d', enabled=trt_version() >= '7.0')
@tensorrt_converter('torch.nn.functional.avg_pool3d', enabled=trt_version() >= '7.0')
def convert_avg_pool_trt7(ctx):
    ...
    # get stride
    # Add the following lines:
    # if stride is None, default to the kernel size
    if stride is None:
        stride = kernel_size
    if not isinstance(stride, tuple):
        stride = (stride, ) * input_dim
    ...

Hope it helps! Cheers.
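
For anyone applying this by hand, the defaulting logic can be tried in isolation first. The following is a minimal standalone sketch, not the torch2trt source: `normalize_stride` is a hypothetical helper name, and `input_dim` is assumed to be the number of spatial dimensions (2 for avg_pool2d, 3 for avg_pool3d). It also shows where the `(None, None)` in the error message plausibly comes from: repeating a `None` stride into a tuple.

```python
def normalize_stride(stride, kernel_size, input_dim):
    """Return stride as a tuple with one entry per spatial dimension.

    Hypothetical helper for illustration; mirrors the PyTorch rule that
    an unspecified stride defaults to the kernel size.
    """
    if not isinstance(kernel_size, tuple):
        kernel_size = (kernel_size,) * input_dim
    if stride is None:
        stride = kernel_size
    if not isinstance(stride, tuple):
        stride = (stride,) * input_dim
    return stride


# Without the None check, a missing stride is simply repeated:
print((None,) * 2)                        # (None, None)

# With the check, the stride falls back to the kernel size:
print(normalize_stride(None, 2, 2))       # (2, 2)
print(normalize_stride(3, 2, 2))          # (3, 3)
print(normalize_stride(None, (2, 4), 2))  # (2, 4)
```

Using `is None` rather than `== None` is the usual Python idiom for this check.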

@Cheahom

Cheahom commented Jun 23, 2021

Hello, my TensorRT version is 6.0.1.5 and I encountered the same problem.

TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: tensorrt.tensorrt.IPoolingLayer, arg1: tensorrt.tensorrt.DimsHW) -> None

Invoked with: <tensorrt.tensorrt.IPoolingLayer object at 0x7f2c90534270>, (None, None)

I have changed torch2trt/torch2trt/converters/avg_pool.py like this:

...
@tensorrt_converter("torch.nn.functional.avg_pool2d", enabled=trt_version() < '7.0')
def convert_avg_pool2d(ctx):
    ...
    # get stride
    if stride == None:
        stride = kernel_size
        # stride = 2
    if not isinstance(stride, tuple):
        # stride = (stride,) * 2
        # stride = 2
        stride = (stride, ) * input.dim() - 2
    ...
@tensorrt_converter('torch.nn.functional.avg_pool3d', enabled=trt_version() >= '7.0')
def convert_avg_pool_trt7(ctx):
    ...
    # get stride
    if stride == None:
        stride = kernel_size
    if not isinstance(stride, tuple):
        stride = (stride, ) * input_dim
    ...

but it's still not working...
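
One thing worth checking in the snippet above is operator precedence in `stride = (stride, ) * input.dim() - 2`: Python evaluates the multiplication first and then tries to subtract 2 from a tuple, which raises a TypeError whenever that branch runs. Whether this explains the failure here is only a guess, since the line is reached only when the stride is not already a tuple. A small standalone illustration, with `ndim` as a hypothetical stand-in for `input.dim()`:

```python
ndim = 4    # hypothetical stand-in for input.dim() on an NCHW tensor
stride = 2

# As written: parsed as ((stride,) * ndim) - 2, i.e. tuple minus int
try:
    bad = (stride,) * ndim - 2
except TypeError as e:
    bad = None
    print("TypeError:", e)

# With parentheses: repeat once per spatial dimension
good = (stride,) * (ndim - 2)
print(good)  # (2, 2)
```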

@abhigoku10

@Cheahom @oliviawindsir I tried your steps, but they did not solve it; instead I am getting a new error:
Warning: Encountered known unsupported method torch.zeros
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_unary
Warning: Encountered known unsupported method torch.Tensor.is_floating_point
Warning: Encountered known unsupported method torch.affine_grid_generator
Warning: Encountered known unsupported method torch.nn.functional.affine_grid
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_variadic
Warning: Encountered known unsupported method torch.grid_sampler
Warning: Encountered known unsupported method torch.nn.functional.grid_sample
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.zeros
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_unary
Warning: Encountered known unsupported method torch.Tensor.is_floating_point
Warning: Encountered known unsupported method torch.affine_grid_generator
Warning: Encountered known unsupported method torch.nn.functional.affine_grid
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_variadic
Warning: Encountered known unsupported method torch.grid_sampler
Warning: Encountered known unsupported method torch.nn.functional.grid_sample
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.zeros
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_unary
Warning: Encountered known unsupported method torch.Tensor.is_floating_point
Warning: Encountered known unsupported method torch.affine_grid_generator
Warning: Encountered known unsupported method torch.nn.functional.affine_grid
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_variadic
Warning: Encountered known unsupported method torch.grid_sampler
Warning: Encountered known unsupported method torch.nn.functional.grid_sample
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.zeros
Warning: Encountered known unsupported method torch.Tensor.cuda
Warning: Encountered known unsupported method torch.nn.functional.has_torch_function_unary
Warning: Encountered known unsupported method torch.Tensor.is_floating_

@ykk648

ykk648 commented Sep 28, 2021

Check padding, or whatever else may cause the error:

# get padding
if not isinstance(padding, tuple):
    padding = (padding[0], ) * input_dim

I fixed it by changing the padding code in torch2trt/torch2trt/converters/avg_pool.py
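
The snippet above suggests the padding can arrive as something other than a tuple (for example a one-element list). A more defensive normalization might look like the sketch below. This is an assumption about the failure mode, not the torch2trt code: `normalize_padding` is a hypothetical helper, and `input_dim` is assumed to be the number of spatial dimensions.

```python
def normalize_padding(padding, input_dim):
    """Return padding as a tuple with one entry per spatial dimension.

    Hypothetical helper for illustration; accepts an int, a list, or a tuple.
    """
    if isinstance(padding, int):
        return (padding,) * input_dim
    padding = tuple(padding)
    if len(padding) == 1:
        # A single value applies to every spatial dimension
        return padding * input_dim
    if len(padding) != input_dim:
        raise ValueError(
            "expected %d padding values, got %d" % (input_dim, len(padding))
        )
    return padding


print(normalize_padding(0, 2))       # (0, 0)
print(normalize_padding([1], 2))     # (1, 1)
print(normalize_padding((1, 2), 2))  # (1, 2)
```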
