Encountered known unsupported method torch.nn.functional.pixel_shuffle #493

Open
zlheos opened this issue Jan 24, 2021 · 6 comments

@zlheos
zlheos commented Jan 24, 2021

I am converting the AlphaPose fast_res50_256x192.pth model, but I run into this unsupported layer. Is there a workaround?

@wwdok
wwdok commented Feb 4, 2021

Hi bro, I am doing exactly the same thing. How is it going now?

@jaybdub
Contributor

jaybdub commented Feb 19, 2021

Hi All,

Thanks for reaching out!

It may be possible to implement this layer using the TensorRT Python API.

https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#ishufflelayer
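For reference, pixel_shuffle only rearranges data: an input of shape (N, C*r*r, H, W) is reshaped to (N, C, r, r, H, W), transposed to (N, C, H, r, W, r), and reshaped to (N, C, H*r, W*r). Reshape and transpose are exactly what a shuffle layer provides, so this is the sequence a converter would need to express. Below is a minimal NumPy sketch of the decomposition, useful for checking layouts without building an engine; `pixel_shuffle_ref` is a hypothetical name, not part of torch2trt or TensorRT.

```python
import numpy as np


def pixel_shuffle_ref(x: np.ndarray, r: int) -> np.ndarray:
    """Hypothetical NumPy reference for torch.nn.functional.pixel_shuffle.

    Maps (N, C*r*r, H, W) -> (N, C, H*r, W*r) using only reshape and transpose.
    """
    n, c_in, h, w = x.shape
    assert c_in % (r * r) == 0, "channels must be divisible by r*r"
    c = c_in // (r * r)
    # Split the channel axis into (C, r, r).
    x = x.reshape(n, c, r, r, h, w)
    # Move each r axis next to the spatial axis it expands: (N, C, H, r, W, r).
    x = x.transpose(0, 1, 4, 2, 5, 3)
    # Merge (H, r) -> H*r and (W, r) -> W*r.
    return x.reshape(n, c, h * r, w * r)


# Usage: 4 channels with r=2 become 1 channel at twice the resolution.
x = np.arange(16).reshape(1, 4, 2, 2)
y = pixel_shuffle_ref(x, 2)
print(y.shape)  # (1, 1, 4, 4)
```

Comparing `pixel_shuffle_ref` with `F.pixel_shuffle` on the same input is a quick sanity check before writing the TensorRT version.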

Documentation on how to add a custom converter to torch2trt is here:

https://nvidia-ai-iot.github.io/torch2trt/master/usage/custom_converter.html

I haven't personally done this yet for the pixel_shuffle layer, but I imagine it is possible.

Let me know if this helps or you run into any issues, or have any questions.

Best,
John

@owoshch
owoshch commented Aug 16, 2021

Hi @zlheos @wwdok! Have you found a way to convert the pixel_shuffle layer to TensorRT?

@464hee
464hee commented Jul 19, 2023

Hello, has anyone implemented a converter for this method yet?

@464hee
464hee commented Jul 19, 2023

I tried the following code and the warning goes away, but the predictions don't match the PyTorch model's predictions at all:

import tensorrt as trt
from torch2trt import tensorrt_converter, add_missing_trt_tensors
import torch.nn.functional as F

@tensorrt_converter('torch.nn.functional.pixel_shuffle')
def convert_pixel_shuffle(ctx):
    input = ctx.method_args[0]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return
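The rest of this converter isn't shown, so the following is only a guess at the cause. One mistake that gives exactly this symptom (no error, correct output shape, wrong values) is reshaping straight to (C, H*r, W*r) without the transpose that interleaves the sub-pixels. This NumPy sketch works on one sample of shape (C*r*r, H, W) and contrasts the two; both helper names are hypothetical.

```python
import numpy as np


def pixel_shuffle_by_definition(x, r):
    """Per-sample pixel shuffle written from its index definition:
    out[c, h*r + i, w*r + j] = x[c*r*r + i*r + j, h, w]."""
    c_in, h, w = x.shape
    c = c_in // (r * r)
    out = np.empty((c, h * r, w * r), dtype=x.dtype)
    for ci in range(c):
        for i in range(r):
            for j in range(r):
                # Channel (ci, i, j) fills every r-th row starting at i
                # and every r-th column starting at j.
                out[ci, i::r, j::r] = x[ci * r * r + i * r + j]
    return out


def reshape_only(x, r):
    # Hypothetical buggy variant: one reshape to the output shape, no transpose.
    c_in, h, w = x.shape
    c = c_in // (r * r)
    return x.reshape(c, h * r, w * r)


x = np.arange(16).reshape(4, 2, 2)
expected = pixel_shuffle_by_definition(x, 2)
wrong = reshape_only(x, 2)
print(wrong.shape == expected.shape)    # True: shapes agree, so nothing errors out
print(np.array_equal(wrong, expected))  # False: values end up in the wrong places
```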

@timegoby
I referred to the code from https://github.com/NVIDIA-AI-IOT/torch2trt/issues/612, and it works for me:

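from torch2trt import tensorrt_converter, add_missing_trt_tensors


@tensorrt_converter('torch.nn.functional.pixel_shuffle')
def convert_PixelShuffle(ctx):
    input = ctx.method_args[0]
    scale_factor = ctx.method_args[1]

    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    batch_size, in_channels, in_height, in_width = input.shape

    assert scale_factor >= 1
    # layer_1 reshapes without a batch axis while layer_2 reshapes with one,
    # so the element counts of the two reshapes only agree when batch_size == 1.
    assert batch_size == 1, "this converter only handles batch_size == 1"

    out_channels = in_channels // (scale_factor * scale_factor)
    out_height = in_height * scale_factor
    out_width = in_width * scale_factor

    # Split the channel axis into (C, r, r): -> (C, r, r, H, W)
    layer_1 = ctx.network.add_shuffle(input_trt)
    layer_1.reshape_dims = (out_channels, scale_factor, scale_factor, in_height, in_width)

    # Interleave sub-pixels: (C, r, r, H, W) -> (C, H, r, W, r), then merge to the output shape
    layer_2 = ctx.network.add_shuffle(layer_1.get_output(0))
    layer_2.first_transpose = (0, 3, 1, 4, 2)
    layer_2.reshape_dims = (batch_size, out_channels, out_height, out_width)

    output._trt = layer_2.get_output(0)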
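To check these layer parameters without building an engine, the two shuffle layers can be simulated in NumPy. This sketch assumes a shuffle layer applies `first_transpose`, then `reshape_dims`, then `second_transpose` (the order given in the TensorRT documentation), and it works on a single sample without the batch axis, so the final reshape is (C, H*r, W*r); with batch_size == 1 that has the same element order as the converter's (1, C, H*r, W*r). `shuffle_layer`, `converter_pixel_shuffle`, and `matches_definition` are hypothetical helpers, not TensorRT APIs. The simulation also shows why two layers are needed: one shuffle layer can transpose around a single reshape, but pixel shuffle needs reshape, transpose, reshape.

```python
import itertools

import numpy as np


def shuffle_layer(x, first_transpose=None, reshape_dims=None, second_transpose=None):
    # Hypothetical NumPy stand-in for a shuffle layer: transpose -> reshape -> transpose.
    if first_transpose is not None:
        x = x.transpose(first_transpose)
    if reshape_dims is not None:
        x = x.reshape(reshape_dims)
    if second_transpose is not None:
        x = x.transpose(second_transpose)
    return x


def converter_pixel_shuffle(x, r):
    # Mirrors layer_1 and layer_2 for one sample of shape (C*r*r, H, W).
    c_in, h, w = x.shape
    c = c_in // (r * r)
    y = shuffle_layer(x, reshape_dims=(c, r, r, h, w))
    return shuffle_layer(y, first_transpose=(0, 3, 1, 4, 2),
                         reshape_dims=(c, h * r, w * r))


def matches_definition(x, out, r):
    # Checks out[c, h*r + i, w*r + j] == x[c*r*r + i*r + j, h, w] for every index.
    c_in, h, w = x.shape
    c = c_in // (r * r)
    return all(
        out[ci, hh * r + i, ww * r + j] == x[ci * r * r + i * r + j, hh, ww]
        for ci, i, j, hh, ww in itertools.product(
            range(c), range(r), range(r), range(h), range(w))
    )


rng = np.random.default_rng(0)
for r, c, h, w in [(2, 3, 4, 5), (3, 2, 2, 3)]:
    x = rng.standard_normal((c * r * r, h, w))
    assert matches_definition(x, converter_pixel_shuffle(x, r), r)
print("layer parameters match pixel_shuffle")
```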
