
fixing dtype promotion in where #1734

Merged
merged 10 commits into main on Feb 4, 2025

Conversation

@jjsjann123 (Collaborator) commented Feb 1, 2025

Fixes #1723

nvFuser's type promotion rule for where differs from Thunder's, so we could accidentally create the output tensor in double, as noted in #1723 (comment).
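
For context, a minimal sketch of the eager-PyTorch behavior Thunder is expected to match (shapes and values here are illustrative):

import torch

cond = torch.rand(3, 3) > 0.5
x = torch.rand(3, 3, dtype=torch.float32)

# A Python float is weakly typed: the tensor dtype wins under promotion.
out = torch.where(cond, x, 1.0)
print(out.dtype)  # torch.float32 -- the scalar does not promote the output

# Before this fix, the nvFuser executor could treat the Python float as a
# strong double, promoting the fused output to torch.float64 instead.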

@jjsjann123 jjsjann123 marked this pull request as ready for review February 1, 2025 04:59

@beverlylytle (Collaborator) left a comment

Nice! Thanks!

@kshitij12345 (Collaborator) left a comment

IIUC, the current fix works because there is another where later in the trace with tensor inputs, which correctly casts the output.

I think we should also handle the following case (which still fails).

  • Fusion with only a prims.where(cond, 1., 0.) symbol:
import thunder
import thunder.examine
import torch

def fn(c, x, y):
    return torch.where(c, x, y)

x = 1.
y = 0.

print(fn(torch.rand(3, 3, device='cuda') > 0.5, x, y).dtype)  # torch.float

jfn = thunder.jit(fn, nv_store_fusion_inputs=True)
print(jfn(torch.rand(3, 3, device='cuda') > 0.5, x, y).dtype)  # torch.double

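# Optional: uncomment the lines below to inspect the final trace and
# extract a standalone nvFuser repro for the fusion's first symbol.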
# trc = thunder.last_traces(jfn)[-1]

# print(trc)
# fusion_syms = thunder.examine.get_fusion_symbols(trc)

# bsym = fusion_syms[0]

# repro = thunder.examine.get_nvfuser_repro(trc, bsym.sym.name)
# print(repro)

@mruberry (Collaborator) commented Feb 3, 2025

> IIUC, the current fix works because there is another where later in the trace with tensor inputs, which correctly casts the output. I think we should also handle the following case (which still fails). […]

Interesting! I'm surprised this case still fails with this fix. Maybe we need more test cases for nvfuser and where, too.

@jjsjann123 (Collaborator, Author)

Thanks for the repro.

I see what the issue is now: the nvFuser executor maps dtypes for number types differently, so I just need another condition to convert the weak type to a strong type when the output is a tensor proxy.
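
For illustration, a hypothetical sketch of that idea; this is not the actual Thunder/nvFuser executor code, and resolve_scalar_dtype is an invented name:

import torch

def resolve_scalar_dtype(scalar, output_is_tensor_proxy, output_dtype):
    # Weakly-typed Python floats should follow the tensor output's dtype
    # instead of defaulting to double, which is how a bare Python float
    # would otherwise be mapped.
    if isinstance(scalar, float) and output_is_tensor_proxy:
        return output_dtype
    if isinstance(scalar, float):
        return torch.float64  # default mapping for a standalone Python float
    return output_dtype

print(resolve_scalar_dtype(1.0, True, torch.float32))   # torch.float32
print(resolve_scalar_dtype(1.0, False, torch.float32))  # torch.float64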

@jjsjann123 (Collaborator, Author)

But I'm surprised this is not covered by the OpInfo tests. I'll add that.
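
As a rough illustration of the kind of coverage this implies (a hypothetical pytest-style sketch, not the actual OpInfo change; names and sample values are invented):

import pytest
import torch
import thunder

@pytest.mark.parametrize("x,y", [(1.0, 0.0), (2.0, torch.tensor(3.0))])
def test_where_scalar_dtype(x, y):
    # The jitted function should match eager PyTorch's dtype promotion,
    # even when one or both value arguments are Python scalars.
    def fn(c, x, y):
        return torch.where(c, x, y)

    c = torch.rand(3, 3) > 0.5
    expected = fn(c, x, y)
    actual = thunder.jit(fn)(c, x, y)
    assert actual.dtype == expected.dtype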

@jjsjann123 (Collaborator, Author)

@kshitij12345 I verified the fix on your repro and added OpInfo tests as well.

Thanks a lot for pointing this out. 🙇

@mruberry (Collaborator) commented Feb 3, 2025

@jjsjann123 Looks like the CI issues are related

@kshitij12345 (Collaborator) left a comment

LGTM, thank you @jjsjann123!

@jjsjann123 jjsjann123 enabled auto-merge (squash) February 4, 2025 16:39
@jjsjann123 jjsjann123 requested a review from mruberry February 4, 2025 16:39

@jjsjann123 (Collaborator, Author)

The code is ready for review again. cc @mruberry

@jjsjann123 (Collaborator, Author)

> @jjsjann123 Looks like the CI issues are related

The CI issue was because we tried to run grad on the newly added scalar inputs, which don't have a grad function.
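
For context, a minimal sketch (shapes and values are illustrative) of why a plain scalar input has no grad function:

import torch

c = torch.rand(3) > 0.5
x = torch.rand(3, requires_grad=True)
y = 0.5  # plain Python scalar input: not a tensor, so autograd cannot track it

out = torch.where(c, x, y).sum()
out.backward()
print(x.grad)              # gradient flows to the tensor input
print(hasattr(y, "grad"))  # False -- a Python float has no grad function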

@mruberry (Collaborator) left a comment

Cool! Nice comments

@jjsjann123 jjsjann123 merged commit 40f7972 into main Feb 4, 2025
49 checks passed
@jjsjann123 jjsjann123 deleted the fix_1723 branch February 4, 2025 17:14
Successfully merging this pull request may close these issues.

div with nvfuser returns incorrect dtype