This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[not for land] enumerate breakages with module hooks + compile #270

Open · wants to merge 1 commit into base: main
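As background (not part of this PR), here is a minimal sketch of the PyTorch module-hook mechanism the diff below relies on: a forward pre-hook receives the module's positional inputs as a tuple and may return replacement inputs (which is why the pre-hook in the diff indexes x[0]), and a forward hook may return a replacement output.

import torch

def pre_hook(mod, args):
    # args is the tuple of positional inputs; a non-None return replaces it
    return (args[0] * 2,)

def post_hook(mod, args, output):
    # a non-None return value replaces the module's output
    return output + 1

lin = torch.nn.Linear(4, 4)
lin.register_forward_pre_hook(pre_hook)
lin.register_forward_hook(post_hook)
y = lin(torch.randn(2, 4))  # runs pre_hook, then forward, then post_hook
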
15 changes: 12 additions & 3 deletions float8_experimental/float8_dynamic_linear.py
@@ -52,6 +52,13 @@ def backward(ctx, gradY):
         )
         return fp8_tensor, None
 
+def forward_pre_hook(mod, x):
+    x = cast_to_float8_e4m3fn(x[0], mod.forward_config)
+    return x
+
+def forward_post_hook(mod, x, y):
+    y = cast_to_float8_e5m2_bw(y, mod.backward_config)
+    return y
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -62,14 +69,14 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
-    def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+    def forward(self, x_fp8):
+        # x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
         if isinstance(self.weight, Float8Tensor): # cast by FSDP
             w_fp8 = self.weight
         else:
             w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        # y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
     @classmethod
@@ -97,6 +104,8 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         else:
             new_mod.weight = mod.weight
         new_mod.bias = mod.bias
+        new_mod.register_forward_pre_hook(forward_pre_hook)
+        new_mod.register_forward_hook(forward_post_hook)
         return new_mod
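A hypothetical usage sketch follows (the import path matches the file changed in this PR; emulate=True is an assumption so the example runs without native float8 kernels): convert a plain Linear with from_float, which now registers the hooks, then wrap the module with torch.compile to surface the breakages the PR title refers to.

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

m = Float8DynamicLinear.from_float(torch.nn.Linear(16, 16), emulate=True)
m = torch.compile(m)  # the PR's goal is to enumerate breakages under compile

x = torch.randn(4, 16)
y = m(x)  # pre-hook casts x to e4m3; post-hook applies the e5m2 backward cast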

