To translate a user program like this:

```python
# t18 = prims.where(t17, -0.0, 0.0)  # t18: "cuda:0 f32[4, 2, 3]"
```
we call nvFuser's `where` on `(boolean_tv, scalar_1, scalar_2)`. The type promotion/inference logic determines the output dtype of `where` from `input[1]` and `input[2]`. We produce an output TV from this operation, but because all scalars are double, nvFuser generates a double tensor as output instead of the float32 recorded in the thunder script (thunder sees the input scalars as float).
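For illustration, here is a minimal sketch of the behavior using nvFuser's Python frontend (the shapes are taken from the trace above; the exact `define_tensor`/`define_scalar` keyword arguments are assumed to match the current frontend API):

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    # Boolean predicate tensor, standing in for t17 in the thunder trace.
    cond = fd.define_tensor(shape=[-1, -1, -1], dtype=DataType.Bool)
    # Python float scalars come in as Double by default.
    a = fd.define_scalar(-0.0)
    b = fd.define_scalar(0.0)
    # Output dtype is promoted from input[1]/input[2], i.e. Double,
    # even though thunder recorded f32 for this op.
    out = fd.ops.where(cond, a, b)
    fd.add_output(out)

cond_in = torch.rand(4, 2, 3, device="cuda") > 0.5
(result,) = fd.execute([cond_in])
print(result.dtype)  # torch.float64, not the torch.float32 thunder expects
```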
This isn't blocking, since we can always patch the executor in thunder to do explicit type inference: Lightning-AI/lightning-thunder#1734
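As a sketch of that workaround (hypothetical placement; the actual patch lives on the thunder side), the translation layer can cast the result back to the dtype thunder recorded before registering it as a fusion output. Replacing the last two lines of the sketch above:

```python
out = fd.ops.where(cond, a, b)          # promoted to Double by the scalars
out = fd.ops.cast(out, DataType.Float)  # explicit cast back to f32
fd.add_output(out)
```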
Per our thunder developers' request, we still want an issue to track this, so that nvFuser users no longer need these WARs.