From 855795c2d8036a5e4d753b3fe5a0d31c4487d498 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 26 Dec 2023 09:54:16 -0800
Subject: [PATCH] [wip] hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 float8_experimental/config.py                 |  8 ++++
 .../dynamic_linear/dynamic_float8_linear.py   | 40 ++++++++++++++++++-
 2 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 0f8b96be..d74c9166 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -16,3 +16,11 @@
 # according to their microbatching/pipeline parallel setup.
 # Note: this is currently a global flag for simplicity and dynamo performance.
 weight_cache_enabled = False
+
+#
+# Other
+#
+
+# If True, dynamic linear uses hooks for activation casting
+dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False
diff --git a/float8_experimental/dynamic_linear/dynamic_float8_linear.py b/float8_experimental/dynamic_linear/dynamic_float8_linear.py
index f0c6a239..16dd4557 100644
--- a/float8_experimental/dynamic_linear/dynamic_float8_linear.py
+++ b/float8_experimental/dynamic_linear/dynamic_float8_linear.py
@@ -11,6 +11,7 @@
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,20 @@ def backward(ctx, gradY):
             None,
         )
 
+
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+
+def cast_dldy_to_float8_e5m2_forward_hook(module, args, output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    new_output = NoopFwToFloat8E5M2Bw.apply(output, module.emulate)
+    return new_output
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +63,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +81,10 @@
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
@@ -69,6 +94,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +118,14 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            # TODO(future): figure out why using backward pre-hooks does not
+            # work here:
+            # 1. repro code: https://gist.github.com/vkuzo/27a3f6ca48e50ba1134b077f0dba254c
+            # 2. repro output: https://gist.github.com/vkuzo/728eae9dcc627e130829d122daa982e7
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_forward_hook(cast_dldy_to_float8_e5m2_forward_hook)
         return new_mod
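
Usage sketch for reviewers (not part of the patch): a minimal example of
exercising the new hooks code path, assuming the module layout shown in the
diff. `emulate=True` keeps the float8 math emulated, so no float8-capable
hardware is needed; the flag, class, and function names are taken from the
patch itself.

    import torch

    import float8_experimental.config as config
    from float8_experimental.dynamic_linear.dynamic_float8_linear import (
        Float8DynamicLinear,
    )

    # enable the hooks-based activation casting added in this patch;
    # must be set before conversion, since from_float reads the flag
    config.dynamic_use_activation_hooks = True

    # convert a plain linear; from_float installs the forward pre-hook
    # (cast x to float8_e4m3fn) and the forward hook (cast gradY to
    # float8_e5m2 in the backward) when the flag is set
    m = Float8DynamicLinear.from_float(torch.nn.Linear(16, 32), emulate=True)

    x = torch.randn(4, 16, requires_grad=True)
    y = m(x)
    y.sum().backward()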