This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

make the backward of differentiable float8 casts pass gradient as is #255

Closed
wants to merge 1 commit

Conversation

@vkuzo (Contributor) commented May 3, 2024

Summary:

Behavior before:

  • high precision to float8 in fw, float8 to high precision in bw
  • float8 to high precision in fw; in bw, cast high precision to float8 if the grad is a Float8Tensor, otherwise pass the gradient unchanged

Behavior after:

  • high precision to float8 in fw, pass gradient unchanged in bw
  • float8 to high precision in fw, pass gradient unchanged in bw
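
For illustration, a minimal sketch of the new wiring, under some simplifying assumptions: the class name and the explicit `scale` argument are made up for this example, scaling/saturation is reduced to a clamp, and the forward returns a plain float8 tensor rather than the repo's Float8Tensor subclass.

```python
import torch

class _CastToFloat8Sketch(torch.autograd.Function):
    """Hypothetical sketch: high precision -> float8 in fw, grad passes through in bw."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
        finfo = torch.finfo(float8_dtype)
        # scale and saturate into the representable float8 range, then cast
        x_scaled = (x * scale).clamp(min=finfo.min, max=finfo.max)
        return x_scaled.to(float8_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # pass the gradient through unchanged; any casting of grads to float8
        # (e.g. to e5m2 right before the backward matmul) belongs elsewhere
        return grad_output, None, None
```

The float8-to-high-precision direction is symmetric: convert in the forward, return the gradient unchanged in the backward.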

Motivation for the new state:

  1. we want gradients to be in high precision unless specified otherwise by the float8 recipe; the logic that casts grads to float8 before the matmul is better implemented elsewhere
  2. there is actually no logic change in this diff, as the backward casts were not being hit by existing code; this diff just makes the intended behavior clearer

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot added the CLA Signed label on May 3, 2024
@vkuzo requested review from albanD, wanchaol, and drisspg on May 3, 2024 18:11
@vkuzo (Contributor, Author) commented May 3, 2024

I noticed this logic was funky when copy-pasting some of this code for MX prototyping; fixing it here to clarify intent.

Review thread on the following diff context:

```
else:
    return grad

if isinstance(x, DTensor):
```
Contributor:
I recalled the main reason we had that special handling is that torch.compile (specifically FakeTensor rule) can't take nested subclass rule automatically. If we remove this wondering where we would put such a logic?

cc @drisspg @bdhirsh
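
For readers outside this thread, a hedged sketch of the general pattern such special handling tends to use: unwrap the DTensor's local shard, convert it, and explicitly rewrap, rather than relying on subclass dispatch / torch.compile to handle the nested-subclass case automatically. The function name is hypothetical and the plain dtype cast stands in for the repo's scaled Float8Tensor construction.

```python
import torch
from torch.distributed._tensor import DTensor

def cast_local_to_float8(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Hypothetical helper: convert the local shard and rewrap, keeping the
    # DTensor(outer) -> float8(inner) nesting explicit.
    if isinstance(x, DTensor):
        local_fp8 = x.to_local().to(float8_dtype)
        return DTensor.from_local(
            local_fp8, x.device_mesh, x.placements, run_check=False
        )
    return x.to(float8_dtype)
```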

Contributor Author (@vkuzo):

if I understand the old state correctly:

  1. grad is never a Float8Tensor in practice, because we always have the matmul output in high precision
  2. things already work without nested subclasses, which is why removing this function is fine

But it would be great to clarify the above.

Contributor:
Oh, I checked the current workflow; we call:

  1. cast_to_float8_e4m3fn
  2. cast_to_float8_e5m2_bw

and it seems neither of these calls would call into this from_fp8_no_autograd, so this should be fine.

Contributor:
Actually, it looks like I'm wrong: cast_to_float8_e4m3fn's backward would call this, but maybe it's a no-op even in the current state, since the output of torch._scaled_mm in the backward is fp32?
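
A hedged, simplified reconstruction of the branch being discussed (a plain dtype check stands in for the repo's isinstance(g, Float8Tensor) check so the snippet is self-contained): because the gradient arriving from the backward matmul is in high precision, only the pass-through branch ever runs, which is why the PR can drop the cast without changing behavior.

```python
import torch

def _old_style_cast_backward(grad: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the pre-PR backward branch under discussion.
    if grad.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # dead branch in the current workflow: the grad from the backward
        # matmul arrives in high precision (e.g. fp32), never as float8
        return grad.to(torch.float32)
    else:
        # what actually runs today, and what the PR makes unconditional
        return grad
```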

@facebook-github-bot:
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot:
@vkuzo merged this pull request in 605fc1d.
