From c1487ef248bdd17be1875f4c39844a818aace093 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Tue, 16 Jul 2024 09:18:37 -0700
Subject: [PATCH 1/2] [TBD if for land] bring back torch.autograd.Function

Summary:

This approach keeps the code more readable as we add additional scaling options.

For now, the goal is to see how many things break (as of 2024-07) when combining
torch.autograd.Function + tensor subclasses + torch.compile.
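
As a rough illustration (not part of this patch), a minimal smoke test of the
torch.autograd.Function + torch.compile part of that combination could look like
the sketch below; the tensor-subclass (Float8Tensor) piece is omitted, and
`_plain_mm` / `TinyLinear` are hypothetical names used only here.

```
import torch

# hypothetical smoke test: an allow_in_graph-ed autograd.Function used
# inside a module that is then run under torch.compile
@torch._dynamo.allow_in_graph
class _plain_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.mm(x, w.t())

    @staticmethod
    def backward(ctx, go):
        x, w = ctx.saved_tensors
        # y = x @ w.t()  =>  dL/dx = go @ w,  dL/dw = go.t() @ x
        return torch.mm(go, w), torch.mm(go.t(), x)

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 16))

    def forward(self, x):
        return _plain_mm.apply(x, self.weight)

m = torch.compile(TinyLinear())
y = m(torch.randn(4, 16, requires_grad=True))
y.sum().backward()
```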

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/float8_linear.py | 102 ++++++++++++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 7850738..787a54c 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -68,6 +68,101 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         )
         scale.copy_(new_scale)
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+# and modified to only support dynamic scaling
+@torch._dynamo.allow_in_graph
+class float8_linear(torch.autograd.Function):
+    """
+    Like F.linear, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8,
+        emulate: bool,
+        # TODO(this PR): split config into fwd/bwd
+        mm_config: ScaledMMConfig,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8)
+        ctx.emulate = emulate
+        ctx.mm_config = mm_config
+        # orig_shape = x_fp8._data.shape
+        orig_shape = x_fp8.shape
+        # x_fp8_reshaped = Float8Tensor(
+        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
+        # )
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+
+        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
+        w_fp8_t = w_fp8.t()
+
+        res_bits = torch.mm(
+            x_fp8_reshaped, w_fp8_t
+        )
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8 = ctx.saved_tensors
+        emulate = ctx.emulate
+        mm_config = ctx.mm_config
+
+        go_fp8_orig_shape = go_fp8.shape
+        # go_fp8_reshaped = Float8Tensor(
+        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+        #     go_fp8._scale,
+        #     go_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # w_fp8_t_c_t = Float8Tensor(
+        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
+        # )
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
+
+        #
+        # calculate dL/dX
+        #
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t_c_t,
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        # x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_orig_shape = x_fp8.shape
+        # x_fp8_reshaped_t_c = Float8Tensor(
+        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+        #     x_fp8._scale,
+        #     x_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
+
+        # go_fp8_reshaped_t_c_t = Float8Tensor(
+        #     go_fp8_reshaped._data.t().contiguous().t(),
+        #     go_fp8_reshaped._scale,
+        #     go_fp8_reshaped._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
+        #
+        # calculate dL/dW
+        #
+        dL_dW = torch.mm(
+            x_fp8_reshaped_t_c,
+            go_fp8_reshaped_t_c_t,
+        )
+        dL_dW = dL_dW.t()
+
+        empty_grads = None, None, None, None, None, None, None, None, None
+        return dL_dX, dL_dW, *empty_grads
+
 
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -394,7 +489,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        if not self.has_any_delayed_scaling:
+            emulate = False
+            mm_config = self.forward_config
+            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
+        else:
+            y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)

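For reference, a quick plain-precision check (independent of this patch, 2-D
shapes only, no Float8Tensor) that the backward formulas used above agree with
torch.matmul's autograd: for y = x @ w.t(), dL/dX = grad_out @ w and
dL/dW = (x.t() @ grad_out).t() = grad_out.t() @ x.

```
import torch

# plain fp64 stand-in for the float8_linear backward math above
x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
w = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
go = torch.randn(4, 8, dtype=torch.float64)

y = torch.matmul(x, w.t())
y.backward(go)

# mirror the patch: dL_dX = go @ (w.t().contiguous().t()), i.e. go @ w
dL_dX_manual = torch.mm(go, w.t().contiguous().t())
# mirror the patch: dL_dW = (x.t() @ go).t(), i.e. go.t() @ x
dL_dW_manual = torch.mm(x.t().contiguous(), go.t().contiguous().t()).t()

assert torch.allclose(x.grad, dL_dX_manual)
assert torch.allclose(w.grad, dL_dW_manual)
```

Note: the nine trailing `None`s returned from `backward` in this patch are
leftovers from the resurrected PR (which had more forward arguments); autograd
tolerates extra trailing gradients as long as they are all `None`.
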
From 8505776f89131152887f1d0f27df731f81c346ab Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Tue, 16 Jul 2024 10:01:30 -0700
Subject: [PATCH 2/2] Update on "[TBD if for land] bring back
 torch.autograd.Function"

Summary:

This approach keeps the code more readable as we add additional scaling options.

For now, the goal is to see how many things break (as of 2024-07) when combining
torch.autograd.Function + tensor subclasses + torch.compile.

```

# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
# and modified to only support dynamic scaling
#
# Why do we want a torch.autograd.Function here? Vasiliy's opinion is that
# as we add more scaling granularities, keeping the scaling code close to Float8Linear
# will be really useful for readability and debuggability of numerics.
#
# For example, a future PR to add rowwise scaling could do
#
#   # forward
#   x_bf16 = ...
#   if scaling_granularity == ScalingGranularity.PER_TENSOR:
#       # we can scale the same way for fwd/bwd
#       x_maybe_fp8 = to_fp8(...)
#   else:
#       assert scaling_granularity == ScalingGranularity.PER_ROW
#       # defer scaling to float8_mm
#       x_maybe_fp8 = x_bf16
#
#   # repeat for w
#
#   y_bf16 = float8_mm(x_maybe_fp8, w_maybe_fp8)
#
#   Requirements for float8_mm
#   - composes with DTensor, compile, autograd
#   - readable/debuggable
#
#   Option 1 (this PR): float8_mm is a torch.autograd.Function
#   - pros
#   - cons
#   Option 2 (current code without this PR): float8_mm is an override of torch.mm
#   - pros
#   - cons
#

```
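
To make Option 1 concrete, here is a minimal, self-contained sketch of that
shape (fake-quantized emulation in plain precision; `ScalingGranularity`,
`to_fp8_emulated`, and `float8_mm_sketch` are hypothetical illustration names,
not the float8_experimental API).

```
import enum

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

class ScalingGranularity(enum.Enum):
    # hypothetical stand-in for a real config enum
    PER_TENSOR = "per_tensor"
    PER_ROW = "per_row"

def to_fp8_emulated(t, dim=None):
    # hypothetical helper: fake-quantize to the float8_e4m3fn range and back,
    # with a per-tensor scale (dim=None) or a per-row scale (dim=-1)
    amax = t.abs().amax() if dim is None else t.abs().amax(dim=dim, keepdim=True)
    scale = E4M3_MAX / amax.clamp(min=1e-12)
    t_fp8 = (t * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return (t_fp8.float() / scale).to(t.dtype)

@torch._dynamo.allow_in_graph
class float8_mm_sketch(torch.autograd.Function):
    """Option 1 shape: the scaling decision lives right next to the mm."""

    @staticmethod
    def forward(ctx, x, w_t, granularity):
        ctx.save_for_backward(x, w_t)
        if granularity == ScalingGranularity.PER_TENSOR:
            x_q, w_q = to_fp8_emulated(x), to_fp8_emulated(w_t)
        else:
            assert granularity == ScalingGranularity.PER_ROW
            # scale x along its rows and w_t along its columns
            x_q = to_fp8_emulated(x, dim=-1)
            w_q = to_fp8_emulated(w_t.t(), dim=-1).t()
        return torch.mm(x_q, w_q)

    @staticmethod
    def backward(ctx, go):
        x, w_t = ctx.saved_tensors
        # keep the sketch's backward in high precision: y = x @ w_t
        # => dL/dx = go @ w_t.t(), dL/dw_t = x.t() @ go, no grad for the enum
        return torch.mm(go, w_t.t()), torch.mm(x.t(), go), None

x = torch.randn(4, 16, requires_grad=True)
w_t = torch.randn(16, 8, requires_grad=True)
y = float8_mm_sketch.apply(x, w_t, ScalingGranularity.PER_ROW)
y.sum().backward()
```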

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/float8_linear.py | 53 ++++------------------------
 1 file changed, 7 insertions(+), 46 deletions(-)

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 787a54c..16c1257 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -68,12 +68,13 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         )
         scale.copy_(new_scale)
 
+
 # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
 # and modified to only support dynamic scaling
 @torch._dynamo.allow_in_graph
-class float8_linear(torch.autograd.Function):
+class float8_mm(torch.autograd.Function):
     """
-    Like F.linear, but with X and W in float8
+    Like torch.mm, but with X and W in float8
     """
 
     @staticmethod
@@ -81,47 +82,24 @@ def forward(
         ctx,
         x_fp8,
         w_fp8,
-        emulate: bool,
-        # TODO(this PR): split config into fwd/bwd
-        mm_config: ScaledMMConfig,
     ):
         ctx.save_for_backward(x_fp8, w_fp8)
-        ctx.emulate = emulate
-        ctx.mm_config = mm_config
-        # orig_shape = x_fp8._data.shape
         orig_shape = x_fp8.shape
-        # x_fp8_reshaped = Float8Tensor(
-        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
-        # )
         x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
 
-        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
         w_fp8_t = w_fp8.t()
 
-        res_bits = torch.mm(
-            x_fp8_reshaped, w_fp8_t
-        )
+        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
         return res_bits
 
     @staticmethod
     def backward(ctx, go_fp8):
         x_fp8, w_fp8 = ctx.saved_tensors
-        emulate = ctx.emulate
-        mm_config = ctx.mm_config
 
         go_fp8_orig_shape = go_fp8.shape
-        # go_fp8_reshaped = Float8Tensor(
-        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
-        #     go_fp8._scale,
-        #     go_fp8._orig_dtype,
-        #     mm_config,
-        # )
         go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
 
-        # w_fp8_t_c_t = Float8Tensor(
-        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
-        # )
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
 
         #
@@ -133,22 +111,9 @@ def backward(ctx, go_fp8):
         )
         dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
 
-        # x_fp8_orig_shape = x_fp8._data.shape
         x_fp8_orig_shape = x_fp8.shape
-        # x_fp8_reshaped_t_c = Float8Tensor(
-        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
-        #     x_fp8._scale,
-        #     x_fp8._orig_dtype,
-        #     mm_config,
-        # )
         x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
 
-        # go_fp8_reshaped_t_c_t = Float8Tensor(
-        #     go_fp8_reshaped._data.t().contiguous().t(),
-        #     go_fp8_reshaped._scale,
-        #     go_fp8_reshaped._orig_dtype,
-        #     mm_config,
-        # )
         go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
 
         #
@@ -160,7 +125,7 @@ def backward(ctx, go_fp8):
         )
         dL_dW = dL_dW.t()
 
-        empty_grads = None, None, None, None, None, None, None, None, None
+        empty_grads = (None,)
         return dL_dX, dL_dW, *empty_grads
 
 
@@ -489,12 +454,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        if not self.has_any_delayed_scaling:
-            emulate = False
-            mm_config = self.forward_config
-            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
-        else:
-            y = torch.matmul(x_fp8, w_fp8.t())
+        # y = float8_mm.apply(x_fp8, w_fp8)
+        y = float8_mm.apply(x_fp8, w_fp8)
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)