fixing dtype promotion in where #1734
Conversation
Nice! Thanks!
IIUC, the current fix works because there is another `where` later in the trace with tensor inputs, which correctly casts the output.
I think we should also handle the following case (which still fails):

- a fusion with only a `prims.where(cond, 1., 0.)` symbol
```python
import thunder
import thunder.examine
import torch


def fn(c, x, y):
    return torch.where(c, x, y)


x = 1.
y = 0.
print(fn(torch.rand(3, 3, device='cuda') > 0.5, x, y).dtype)   # torch.float32
jfn = thunder.jit(fn, nv_store_fusion_inputs=True)
print(jfn(torch.rand(3, 3, device='cuda') > 0.5, x, y).dtype)  # torch.float64

# trc = thunder.last_traces(jfn)[-1]
# print(trc)
# fusion_syms = thunder.examine.get_fusion_symbols(trc)
# bsym = fusion_syms[0]
# repro = thunder.examine.get_nvfuser_repro(trc, bsym.sym.name)
# print(repro)
```
Interesting! I'm surprised this case still fails with this fix. Maybe we need more test cases for nvfuser and `where`, too.
Thanks for the repro. I see what the issue is now: the nvfuser executor maps dtypes differently for number types, so I just need another condition to convert the weak type to a strong type when the output is a tensor proxy.
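The idea, as a minimal standalone sketch in plain PyTorch (the helper name and casting logic are illustrative assumptions, not thunder's or nvfuser's actual executor code):

```python
import torch

def where_with_scalar_args(cond, a, b):
    # Illustrative only: Python scalars carry "weak" dtypes, so they should
    # be cast to the promoted ("strong") tensor dtype rather than falling
    # back to float64 the way a naive number-to-dtype mapping would.
    dtype = torch.result_type(a, b)  # treats Python numbers as weak types
    a = torch.as_tensor(a, dtype=dtype, device=cond.device)
    b = torch.as_tensor(b, dtype=dtype, device=cond.device)
    return torch.where(cond, a, b)

cond = torch.rand(3, 3) > 0.5
print(where_with_scalar_args(cond, 1., 0.).dtype)  # torch.float32, not float64
```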
But I'm surprised this isn't part of the opinfo tests. I'll add that.
@kshitij12345 verified the fix on your repro. Added tests in opinfo as well. Thanks a lot for pointing this out. 🙇
@jjsjann123 Looks like the CI issues are related.
LGTM, thank you @jjsjann123!
Code is good for review again. cc'ing @mruberry
The CI issue was because we tried to run grad on the newly added scalar inputs, which don't have a grad function.
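For context, a small plain-PyTorch illustration of the constraint (an assumed setup, not the actual opinfo test code): gradients only flow to tensor inputs, so a Python scalar argument has nothing to differentiate.

```python
import torch

cond = torch.rand(3) > 0.5
x = torch.rand(3, requires_grad=True)

out = torch.where(cond, x, 0.)  # the second branch is a plain Python scalar
out.sum().backward()            # grad flows only to the tensor input
print(x.grad)                   # defined for x; the scalar 0. has no grad_fn
```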
Fixes #1723

nvfuser's type promotion rule for `where` is different from thunder's. We could accidentally create the output tensor in double, as noted in #1723 (comment).
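For reference, the torch promotion behavior thunder matches, shown here in plain PyTorch on CPU: weak Python scalars keep the tensor's dtype, while a strong float64 tensor operand widens the result.

```python
import torch

cond = torch.rand(2, 2) > 0.5
t = torch.rand(2, 2)  # float32

print(torch.where(cond, t, 1.).dtype)          # torch.float32: weak scalar
print(torch.where(cond, t, t.double()).dtype)  # torch.float64: strong tensor
```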