This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
make all 3 gemms in Float8Linear support configurability, not user facing #315
Closed
Changes from all commits (5 commits):
4895a52 (vkuzo): [wip] make all 3 gemms in Float8Linear configurable
e3545db (vkuzo): Update on "[wip] make all 3 gemms in Float8Linear configurable"
bfe17f9 (vkuzo): Update on "[wip] make all 3 gemms in Float8Linear configurable"
57a4b32 (vkuzo): Update on "[wip] make all 3 gemms in Float8Linear configurable"
34934ce (vkuzo): Update on "make all 3 gemms in Float8Linear support configurability, …
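For context (not stated explicitly on this page): a linear layer performs three matmuls per training step, which is what "all 3 gemms" in the PR title refers to. The diff below replaces the previous forward_config / backward_config pair with a single LinearMMConfig holding one ScaledMMConfig per entry, and tags each cast tensor with a GemmInputRole (X, W, or DL_DY) so that, presumably, the matching per-gemm settings can be chosen when the tensor is later used in a scaled matmul. The sketch below is a plain-PyTorch illustration of the three gemms only; the tensor names and shapes are made up for the example and are not float8_experimental code.

import torch

# Conceptual illustration of the three gemms in y = x @ w.t() during one
# training step of a linear layer.
x = torch.randn(16, 32)       # input
w = torch.randn(64, 32)       # weight
dL_dY = torch.randn(16, 64)   # gradient flowing into the layer output

y = x @ w.t()          # gemm 1: forward output
dL_dX = dL_dY @ w      # gemm 2: gradient w.r.t. the input x
dL_dW = dL_dY.t() @ x  # gemm 3: gradient w.r.t. the weight w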
Diff view
@@ -23,6 +23,8 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor

     @staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):
         self.create_buffers()

         # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )

         # Note: is_amax_initialized is not a buffer to avoid data dependent

Review comment on this hunk: "[I think this might be another stylistic thing so no need to change]: I think I would actually make this a func and then super document it. It's not very clear reading this what everything does, so I would clearly explain in that func the exact recipe that we choose by default."
Reply: "this isn't user facing, so we can clean up at any time."
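A minimal sketch of the helper the reviewer is suggesting, assuming the defaults shown in the hunk above. The function name is hypothetical and not part of this PR, and the middle ScaledMMConfig fields are simply mirrored from the diff rather than interpreted.

from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig

def _default_linear_mm_config(emulate: bool, pad_inner_dim: bool) -> LinearMMConfig:
    # Hypothetical helper (review suggestion, not implemented in this PR):
    # build the default per-gemm recipe in one documented place instead of
    # inline in Float8Linear.__init__.
    #
    # The x and w entries share the same settings; the dL_dY entry differs
    # only in the second field, exactly as in the __init__ hunk above
    # (True if not emulate else False == not emulate).
    fwd_like = ScaledMMConfig(emulate, not emulate, False, pad_inner_dim)
    dl_dy = ScaledMMConfig(emulate, False, False, pad_inner_dim)
    return LinearMMConfig(
        fwd_like,  # x
        fwd_like,  # w
        dl_dy,     # dL_dY
    )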
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8

     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y

     def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
             new_mod.weight = torch.nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                 )
             )
         else:
@@ -468,7 +484,7 @@ def from_float(
                 new_mod.fp8_amax_w,
                 new_mod.fp8_amax_history_w,
                 new_mod.fp8_scale_w,
-                new_mod.forward_config,
+                new_mod.linear_mm_config,
                 new_mod.is_amax_initialized,
             )
         )
Review comment: "maybe no default value for gemm role?"
Reply: "we can clean up in a separate PR, there is extra complexity because we'd need to change the argument order."
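To illustrate the argument-order point (the signatures below are hypothetical sketches, not the library's actual API): Python does not allow a positional parameter without a default after parameters that have defaults, so removing the default from a trailing gemm_input_role parameter means either reordering the parameters or making the trailing ones keyword-only, both of which touch every call site.

from float8_experimental.float8_tensor import GemmInputRole

# Hypothetical cast signature, for illustration only.
def cast_sketch(tensor, scale, float8_dtype,
                linear_mm_config=None, gemm_input_role=GemmInputRole.X):
    ...

# Removing only the default is invalid Python:
#   def cast_sketch(tensor, scale, float8_dtype,
#                   linear_mm_config=None, gemm_input_role):  # SyntaxError
#
# One option that keeps the parameter order is to make the trailing
# parameters keyword-only, which changes how callers must pass them:
def cast_sketch_v2(tensor, scale, float8_dtype, *,
                   linear_mm_config=None, gemm_input_role):
    ...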