Writing custom log_softmax converter #865

Open

gabe-scorebreak opened this issue May 30, 2023 · 0 comments
Hi, I want to convert torch.nn.functional.log_softmax to TensorRT; however, some of my tests fail and I don't know why.

This is what I came up with:

from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter("torch.Tensor.log_softmax")
@tensorrt_converter("torch.nn.functional.log_softmax")
def convert_log_softmax(ctx):
    input = ctx.method_args[0]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    # get dims from args or kwargs
    if "dim" in ctx.method_kwargs:
        dim = ctx.method_kwargs["dim"]
    elif len(ctx.method_args) >= 2:
        dim = ctx.method_args[1]

    # convert negative dims
    if dim < 0:
        dim = len(input.shape) + dim

    axes = torch_dim_to_trt_axes(dim)

    # softmax over the selected axes...
    layer = ctx.network.add_softmax(input=input_trt)
    layer.axes = axes
    # ...followed by an elementwise natural log, i.e. log(softmax(x))
    layer = ctx.network.add_unary(input=layer.get_output(0), op=trt.UnaryOperation.LOG)

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module():
    return torch.nn.LogSoftmax(1)

@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_dim2():
    return torch.nn.LogSoftmax(2)

@add_module_test(torch.float32, torch.device("cuda"), [(1, 3)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_neg1():
    return torch.nn.LogSoftmax(-1)

@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_dim_neg2():
    return torch.nn.LogSoftmax(-2)

This is not very original: I just took the existing softmax converter and added the line layer = ctx.network.add_unary(input=layer.get_output(0), op=trt.UnaryOperation.LOG), since log softmax is just softmax followed by an elementwise log (a standalone check of that decomposition is included after the results below). Well, the tests fail:

|               torch2trt.converters.log_softmax.test_log_softmax_module | float32 |            [(1, 3, 3, 3)] | {} | 2.38E-07 | 155.45 | 3.77E-15 | 8.48e+04 | 1.19e+04 | 0.0893 | 0.178 |
|               torch2trt.converters.log_softmax.test_log_softmax_module | float32 |                  [(1, 3)] | {} | 1.99E+00 | 2.97 | 2.01E+00 | 8.02e+04 | 1.25e+04 | 0.0917 | 0.192 |
|          torch2trt.converters.log_softmax.test_log_softmax_module_dim2 | float32 |            [(1, 3, 3, 3)] | {} | 1.86E+00 | 11.72 | 8.45E-01 | 7.96e+04 | 1.16e+04 | 0.0918 | 0.188 |
|          torch2trt.converters.log_softmax.test_log_softmax_module_neg1 | float32 |            [(1, 3, 3, 3)] | {} | 2.13E+00 | 11.94 | 8.13E-01 | 7.93e+04 | 1.18e+04 | 0.0987 | 0.184 |
|          torch2trt.converters.log_softmax.test_log_softmax_module_neg1 | float32 |                  [(1, 3)] | {} | 3.49E+00 | 2.71 | 6.53E+00 | 8.4e+04 | 1.23e+04 | 0.0999 | 0.188 |
|      torch2trt.converters.log_softmax.test_log_softmax_module_dim_neg2 | float32 |            [(1, 3, 3, 3)] | {} | 1.45E+00 | 11.10 | 6.35E-01 | 7.47e+04 | 1.23e+04 | 0.111 | 0.191 |
NUM_TESTS: 6
NUM_SUCCESSFUL_CONVERSION: 6
NUM_FAILED_CONVERSION: 0
NUM_ABOVE_TOLERANCE: 5
NUM_pSNR_TOLERANCE: 0
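
As a standalone sanity check of the decomposition itself, independent of TensorRT, I would expect something like the following to agree to within float32 rounding for these shapes and dims. This snippet is just for illustration and is not part of the converter:

import torch
import torch.nn.functional as F

# Check that log(softmax(x)) matches log_softmax(x) for the tested shapes/dims.
# This only exercises the math, not the TensorRT conversion path.
for shape, dim in [((1, 3), 1), ((1, 3, 3, 3), 1), ((1, 3, 3, 3), 2), ((1, 3, 3, 3), -1), ((1, 3, 3, 3), -2)]:
    x = torch.randn(shape, device="cuda")
    composed = torch.log(F.softmax(x, dim=dim))
    reference = F.log_softmax(x, dim=dim)
    print(shape, dim, (composed - reference).abs().max().item())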

Interestingly enough, one test passes and the other five fail. What worries me is that the errors are quite large. At first I thought a more numerically stable formula like logsoftmax(x) = (x - x_max) - log(sum(exp(x - x_max))) might perform better here, and I tried implementing it, but to no avail: still one test passes and the rest fail. Even with the current implementation I wouldn't expect such extreme errors, which tells me something is wrong on a fundamental level. I would appreciate any help @jaybdub
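
For reference, this is roughly how I'd expect the stable formula to be spelled out with TensorRT reduce/elementwise/unary layers. It is only a sketch in the same style as the converter above, reusing the same torch2trt helpers, and it would replace that converter if imported; I haven't verified that it changes the results:

@tensorrt_converter("torch.Tensor.log_softmax")
@tensorrt_converter("torch.nn.functional.log_softmax")
def convert_log_softmax_stable(ctx):
    # Sketch of logsoftmax(x) = (x - x_max) - log(sum(exp(x - x_max))),
    # built from reduce/elementwise/unary layers instead of add_softmax + LOG.
    input = ctx.method_args[0]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    # get dim from args or kwargs, same as above
    if "dim" in ctx.method_kwargs:
        dim = ctx.method_kwargs["dim"]
    elif len(ctx.method_args) >= 2:
        dim = ctx.method_args[1]
    if dim < 0:
        dim = len(input.shape) + dim
    axes = torch_dim_to_trt_axes(dim)

    # x_max over the reduced axes, keeping dims so it broadcasts back onto x
    max_layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MAX, axes, keep_dims=True)
    # shifted = x - x_max
    shifted = ctx.network.add_elementwise(input_trt, max_layer.get_output(0), trt.ElementWiseOperation.SUB)
    # exp(shifted)
    exp_layer = ctx.network.add_unary(shifted.get_output(0), trt.UnaryOperation.EXP)
    # sum over the same axes, again keeping dims
    sum_layer = ctx.network.add_reduce(exp_layer.get_output(0), trt.ReduceOperation.SUM, axes, keep_dims=True)
    # log(sum(exp(shifted)))
    log_layer = ctx.network.add_unary(sum_layer.get_output(0), trt.UnaryOperation.LOG)
    # shifted - log(sum(exp(shifted)))
    result = ctx.network.add_elementwise(shifted.get_output(0), log_layer.get_output(0), trt.ElementWiseOperation.SUB)

    output._trt = result.get_output(0)
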
I will gladly create a PR if you help me get it working.
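
In case it helps narrow things down, this is the kind of direct comparison I can run outside the test harness for one of the failing shapes (assuming the converter module above has been imported so the converter is registered):

import torch
from torch2trt import torch2trt

# Compare PyTorch and TensorRT outputs directly for one failing case.
model = torch.nn.LogSoftmax(dim=-1).cuda().eval()
x = torch.randn(1, 3, device="cuda")

model_trt = torch2trt(model, [x])

y = model(x)
y_trt = model_trt(x)
print("max abs error:", (y - y_trt).abs().max().item())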
