
Hardswish example #426

Open · samiit opened this issue Oct 12, 2020 · 7 comments


samiit commented Oct 12, 2020

Can anyone kindly comment on whether my implementation of hardswish is correct? I can well imagine that it is not the most efficient, but it is the best I could come up with.

import tensorrt as trt
from torch2trt import tensorrt_converter
from torch2trt.torch2trt import *

@tensorrt_converter('torch.nn.Hardswish.forward')
@tensorrt_converter('torch.nn.functional.hardswish')
def convert_Hardswish(ctx):
    # h-swish(x) = x * ReLU6(x+3)/6 # source: https://paperswithcode.com/method/hard-swish
    input = ctx.method_args[1]
    output = ctx.method_return
    
    input_a_trt, input_b_trt, input_c_trt, input_d_trt = add_missing_trt_tensors(ctx.network, [input, 6., 1./6., 3.])
    input_a_trt, input_b_trt, input_c_trt, input_d_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt, input_c_trt, input_d_trt], len(output.shape) - 1)
    
    # ReLU6(x+3)
    layer = ctx.network.add_elementwise(input_a_trt, input_d_trt, trt.ElementWiseOperation.SUM)
    layer = ctx.network.add_activation(input=layer.get_output(0), type=trt.ActivationType.RELU)
    layer = ctx.network.add_elementwise(layer.get_output(0), input_b_trt, trt.ElementWiseOperation.MIN)

    # ReLU6(x+3)/6
    layer = ctx.network.add_elementwise(layer.get_output(0), input_c_trt, trt.ElementWiseOperation.PROD)

    # x*ReLU6(x+3)/6
    layer = ctx.network.add_elementwise(input_a_trt, layer.get_output(0), trt.ElementWiseOperation.PROD)
    
    output._trt = layer.get_output(0)
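
A quick way to sanity-check a converter like this (a sketch, assuming torch2trt and a CUDA device are available) is to convert a tiny model that uses the module form and compare its output against PyTorch:

import torch
from torch2trt import torch2trt

# Minimal test model using the nn.Hardswish module form.
model = torch.nn.Sequential(torch.nn.Hardswish()).eval().cuda()
x = torch.randn(1, 3, 32, 32).cuda()

model_trt = torch2trt(model, [x])
# The converted output should match PyTorch up to small numerical error.
print(torch.max(torch.abs(model(x) - model_trt(x))))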

Thanks for the great repo!

Sam

@HamsterHuey

I think this may be incorrect for the torch.nn.functional.hardswish registration, since the relevant input for the functional form would be ctx.method_args[0]. See the class vs. functional implementations of the ReLU6 converters as an example:

https://github.com/NVIDIA-AI-IOT/torch2trt/blob/b0cc8e77a0fbd61e96b971a66bbc11326f77c6b5/torch2trt/converters/ReLU6.py

https://github.com/NVIDIA-AI-IOT/torch2trt/blob/b0cc8e77a0fbd61e96b971a66bbc11326f77c6b5/torch2trt/converters/relu6.py
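
One possible way to handle both forms (a sketch adapted from the snippet above; the helper name _convert_hardswish_impl and the second converter name are illustrative, not from torch2trt) is to register them separately and pull the input tensor from the right argument position:

import tensorrt as trt
from torch2trt import tensorrt_converter
from torch2trt.torch2trt import *

def _convert_hardswish_impl(ctx, input):
    # h-swish(x) = x * ReLU6(x + 3) / 6
    output = ctx.method_return
    x_trt, six_trt, sixth_trt, three_trt = add_missing_trt_tensors(ctx.network, [input, 6.0, 1.0 / 6.0, 3.0])
    x_trt, six_trt, sixth_trt, three_trt = broadcast_trt_tensors(ctx.network, [x_trt, six_trt, sixth_trt, three_trt], len(output.shape) - 1)

    # ReLU6(x + 3)
    layer = ctx.network.add_elementwise(x_trt, three_trt, trt.ElementWiseOperation.SUM)
    layer = ctx.network.add_activation(input=layer.get_output(0), type=trt.ActivationType.RELU)
    layer = ctx.network.add_elementwise(layer.get_output(0), six_trt, trt.ElementWiseOperation.MIN)

    # x * ReLU6(x + 3) / 6
    layer = ctx.network.add_elementwise(layer.get_output(0), sixth_trt, trt.ElementWiseOperation.PROD)
    layer = ctx.network.add_elementwise(x_trt, layer.get_output(0), trt.ElementWiseOperation.PROD)
    output._trt = layer.get_output(0)

@tensorrt_converter('torch.nn.Hardswish.forward')
def convert_Hardswish(ctx):
    # Bound method: method_args[0] is the module (self), the input tensor is at index 1.
    _convert_hardswish_impl(ctx, ctx.method_args[1])

@tensorrt_converter('torch.nn.functional.hardswish')
def convert_hardswish_functional(ctx):
    # Functional form: the input tensor is the first positional argument.
    _convert_hardswish_impl(ctx, ctx.method_args[0])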

@Jason-Lee0

@samiit Hi, was your Hardswish.py implementation successful?

I am facing the same problem with Hardswish and Hardsigmoid. Do you have any advice on how to convert these blocks?

Sorry to bother you.
Thanks.

@HamsterHuey

@ntut108318099 - I took the lazy way out and just wrote the Hardswish implementation in PyTorch, then replaced all the activations in the model with my version, and it converts just fine:

# Swish, HardSigmoid and HardSwish are adapted from:
# https://github.com/Randl/MobileNetV3-pytorch/blob/master/MobileNetV3.py
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def swish(x: Tensor) -> Tensor:
    return x * x.sigmoid()


def hard_sigmoid(x: Tensor, inplace: bool = False) -> Tensor:
    return F.relu6(x + 3, inplace) / 6


def hard_swish(x: Tensor, inplace: bool = False) -> Tensor:
    return x * hard_sigmoid(x, inplace)


class Hardswish2(nn.Module):
    def __init__(self, inplace=False):
        super(Hardswish2, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, inplace=self.inplace)

@Jason-Lee0

@HamsterHuey Thanks for your reply.
So you modified the block in the MobileNetV3 module, not in the torch2trt module?

That's an amazing idea!
I will try the code you shared.

Thanks for your help.
Have a nice day!

@HamsterHuey

@ntut108318099 - I just copied that code from the MobileNet repo into my own. Basically, you define the Hardswish activation using PyTorch primitives. Because the HardSwish here is defined in terms of F.relu6, the torch2trt conversion works, since torch2trt ships with a converter for F.relu6. So all you need to do is:

  1. Implement your own HardSwish activation using the code I posted above.
  2. Replace any torch.nn.Hardswish activations with your custom one. Either do this before instantiating the model, if you can directly edit the model definitions, or loop through each module in the model and dynamically replace the torch.nn.Hardswish activations with the new definition (a sketch of this swap is shown below). Once you do that, torch2trt will handle the conversion without any edits to the torch2trt library.
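
A rough sketch of the dynamic replacement (assuming the Hardswish2 class from the earlier comment; the helper name replace_hardswish is arbitrary):

import torch.nn as nn

def replace_hardswish(module: nn.Module) -> None:
    # Recursively swap every nn.Hardswish child for the torch2trt-friendly Hardswish2.
    for name, child in module.named_children():
        if isinstance(child, nn.Hardswish):
            setattr(module, name, Hardswish2())
        else:
            replace_hardswish(child)

# Usage: call replace_hardswish(model) once before running torch2trt on the model.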

@Jason-Lee0

@HamsterHuey Hi, I followed your tutorial and it worked!
The TensorRT model looks very good. Thanks for your big help.

Have a nice day.

@Chelovek760

@ntut108318099 Hi! Could you please share your solution? My implementation returns NaN.
