make the backward of differentiable float8 casts pass gradient as is #255
Conversation
Summary:

Behavior before:
* high precision to float8 in fw, float8 to high precision in bw
* float8 to high precision in fw; in bw, high precision to float8 if grad is a Float8Tensor, otherwise pass gradient unchanged

Behavior after:
* high precision to float8 in fw, pass gradient unchanged in bw
* float8 to high precision in fw, pass gradient unchanged in bw

Motivation for the new state:
1. we want gradients to be in high precision unless specified otherwise by the float8 recipe, and the logic to specify grad casting to float8 before the matmul is better implemented elsewhere
2. there is actually no logic change in this diff, as the backward casts were not getting hit from existing code; this diff just makes the intended behavior clearer

Test Plan:
```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
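A minimal sketch of the behavior after this PR, assuming a recent PyTorch build with float8 dtypes; the class names and the scale-free casts below are illustrative stand-ins, not the repo's actual implementations:

```
import torch

class _CastToFloat8(torch.autograd.Function):
    """High precision -> float8 in fw; gradient passed through unchanged in bw."""

    @staticmethod
    def forward(ctx, tensor, float8_dtype):
        # scale-free cast, purely illustrative; the real code applies a scale
        # and wraps the result in a Float8Tensor subclass
        return tensor.to(float8_dtype)

    @staticmethod
    def backward(ctx, grad):
        # pass the gradient as-is, in whatever precision it arrives
        return grad, None


class _CastFromFloat8(torch.autograd.Function):
    """Float8 -> high precision in fw; gradient passed through unchanged in bw."""

    @staticmethod
    def forward(ctx, tensor, orig_dtype):
        return tensor.to(orig_dtype)

    @staticmethod
    def backward(ctx, grad):
        return grad, None
```

Usage would look like `y_fp8 = _CastToFloat8.apply(x, torch.float8_e4m3fn)`; autograd then delivers whatever gradient the downstream ops produce straight through to `x`, without any backward cast.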
I noticed this logic was funky when copy-pasting some of this code for MX prototyping, fixing here to clarify intent.
```
    else:
        return grad
```

```
    if isinstance(x, DTensor):
```
if I understand the old state correctly:
- `grad` is never a `Float8Tensor` in practice, because we always have the matmul output high precision
- things already work without nested subclasses, which is why removing this function is fine

But, would be great to clarify ^
Oh I checked the current workflow, we call:
- `cast_to_float8_e4m3fn`
- `cast_to_float8_e5m2_bw`

and it seems none of these calls would call into this `from_fp8_no_autograd`, so this should be fine
actually it looks like I'm wrong, `cast_to_float8_e4m3fn`'s backward would call this, but maybe it's a no-op even as of the current state, as the output of the torch.scaled_mm in backward is fp32?
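To make the call flow in this thread concrete, here is a heavily simplified, hypothetical sketch: the function name, the scale-free casts, and the plain-matmul stand-in for the scaled matmul are all assumptions, not the repo's actual code. The point it illustrates is that the matmul output, and therefore the gradient flowing back into the e4m3 cast's backward, is a plain high-precision tensor rather than a Float8Tensor, so the removed backward branch was not being exercised.

```
import torch

def float8_mm_sketch(x_hp: torch.Tensor, w_hp: torch.Tensor) -> torch.Tensor:
    # forward: cast both operands to e4m3 (stand-in for cast_to_float8_e4m3fn;
    # the real code applies scales and wraps results in Float8Tensor)
    x_fp8 = x_hp.to(torch.float8_e4m3fn)
    w_fp8 = w_hp.to(torch.float8_e4m3fn)
    # the matmul produces a high-precision output (the real code asks the
    # scaled matmul for a high-precision out_dtype; here we simply upcast and
    # use a regular matmul), so any gradient autograd later sends back through
    # the e4m3 casts is a plain high-precision tensor, never a Float8Tensor
    return x_fp8.to(torch.float32) @ w_fp8.to(torch.float32).t()
```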
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.