Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

bring back torch.autograd.Function #316

Open
wants to merge 5 commits into
base: gh/vkuzo/29/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion float8_experimental/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,54 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
scale.copy_(new_scale)


# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Does the structure work out to put this in float8 ops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in how things look after this PR it would make sense, but might be good to see how the code looks after we add different granularities and the if/else branches on when to convert to lower precision. Maybe we can revisit then?

@torch._dynamo.allow_in_graph
class manual_float8_mm(torch.autograd.Function):
"""
Like torch.mm, but with X and W in float8
"""

@staticmethod
def forward(
ctx,
x_fp8,
w_fp8_t,
):
ctx.save_for_backward(x_fp8, w_fp8_t)
orig_shape = x_fp8.shape
x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
vkuzo marked this conversation as resolved.
Show resolved Hide resolved
res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits

@staticmethod
def backward(ctx, go_fp8):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: align go_fp8 / other naming to the other PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can do that in separate PRs, since not user facing. Just keeping things small.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont know if that changes the size of the PR much but sure thats fine

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably just a style difference on how to sequence the renames, either is ok IMO

x_fp8, w_fp8_t = ctx.saved_tensors

go_fp8_orig_shape = go_fp8.shape
go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])

# calculate dL/dX
dL_dX = torch.mm(
go_fp8_reshaped,
w_fp8_t.t(),
)
dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])

x_fp8_orig_shape = x_fp8.shape
x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])

# calculate dL/dW
# Note: the variant below is slightly faster on LLaMa 3 8B pretraining
# compared to than calculating `dL_dW_t = x_fp8_t @ go_fp8_reshaped`
dL_dW = torch.mm(
go_fp8_reshaped.t(),
x_fp8_reshaped,
)

return dL_dX, dL_dW.t()


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
"""
Expand Down Expand Up @@ -410,7 +458,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())
y = manual_float8_mm.apply(x_fp8, w_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y)
Expand Down
4 changes: 3 additions & 1 deletion float8_experimental/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def choose_scaled_mm_config(
a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
return a_linear_mm_config.dL_dX
elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
elif (a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY) or (
a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
):
assert (
a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
Expand Down
Loading