Added eager fp8 all-gather for dynamic scaling (subclass) (#247)
Summary:
Pull Request resolved: #247

This PR requires pytorch/pytorch#122908.

This PR adds an option for FSDP fp8 all-gather with FSDP2, focusing on correctness first.
- We add a config option `enable_fsdp_fp8_all_gather: bool = False` in `float8_experimental.config` that can be set to `True` to enable the new behavior.
- We only support `Float8DynamicLinear` for now.
- For the fp8 all-gather, we make `Float8DynamicLinear.weight` a tensor subclass, `Float8DynamicLinearWeightTensor`. This can be thought of as a `torch.Tensor` weight that knows how to cast itself to fp8. Since the cast logic is specific to dynamic scaling (i.e. it would not be exactly the same for delayed scaling), the subclass is simply named `Float8DynamicLinearWeightTensor`.
- The subclass is a `__torch_dispatch__` wrapper subclass, since that is the kind of subclass most compatible with `torch.compile`.
- The subclass defines `fsdp_pre_all_gather()` and `fsdp_post_all_gather()` to implement the fp8 all-gather.
- We include unit tests that cover eager-mode correctness and memory usage.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D56192107

Pulled By: awgu

fbshipit-source-id: b9e36fde0e39ca59658ddef86932a0b1ed2ca461
1 parent 31877bb · commit ac065d0 · 8 changed files with 691 additions and 34 deletions.
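To make the mechanism described above concrete, here is a minimal sketch of a `__torch_dispatch__` wrapper subclass that exposes `fsdp_pre_all_gather()` and `fsdp_post_all_gather()` hooks. This is not the actual `Float8DynamicLinearWeightTensor`: the class name, hook signatures, and scaling math below are illustrative assumptions based on the FSDP2 tensor-subclass extension point from pytorch/pytorch#122908, and the real implementation returns a proper float8 tensor wrapper integrated with `Float8DynamicLinear`.

# Hypothetical sketch; not the code added by this commit.
import torch
from torch.utils._pytree import tree_map


class Fp8AllGatherWeight(torch.Tensor):
    """Illustrative wrapper subclass: a weight that casts itself to fp8 for FSDP all-gather."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the inner tensor for all ops; the real subclass also re-wraps
        # outputs so the subclass survives optimizer updates.
        unwrap = lambda t: t._data if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

    def fsdp_pre_all_gather(self, mesh):
        # Dynamic scaling: cast the local shard to fp8 so the all-gather runs in fp8.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        amax = self._data.abs().amax().float().clamp(min=1e-12)
        scale = fp8_max / amax
        fp8_shard = (self._data.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return (fp8_shard,), (scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        # Dequantize the gathered fp8 tensor back to the compute dtype. The real
        # implementation instead keeps it wrapped as a float8 tensor so matmuls
        # can stay in fp8; the `out` path is omitted in this sketch.
        (fp8_weight,) = all_gather_outputs
        (scale,) = metadata
        return (fp8_weight.to(param_dtype) / scale.to(param_dtype)), (fp8_weight,)

In the new unit tests, this behavior is gated by `config.enable_fsdp_fp8_all_gather`, toggled with the `set_enable_fsdp_fp8_all_gather` context manager shown in the diff below.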
@@ -0,0 +1,84 @@ | ||
import contextlib | ||
from typing import List, Type | ||
|
||
import float8_experimental.config as config | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import torch.nn as nn | ||
from float8_experimental.float8_linear import Float8Linear | ||
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history | ||
|
||
|
||
def check_parity_no_mp( | ||
test_cls, | ||
ref_model: nn.Module, | ||
ref_optim: torch.optim.Optimizer, | ||
fsdp_model: nn.Module, | ||
fsdp_optim: torch.optim.Optimizer, | ||
local_inp: torch.Tensor, | ||
module_cls: Type, | ||
): | ||
for iter_idx in range(10): | ||
losses: List[torch.Tensor] = [] | ||
for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): | ||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||
losses.append(model(local_inp).sum()) | ||
losses[-1].backward() | ||
if model is ref_model: | ||
for param in model.parameters(): | ||
dist.all_reduce(param.grad) | ||
param.grad.div_(dist.get_world_size()) | ||
if module_cls is Float8Linear: | ||
sync_float8_amax_and_scale_history(model) | ||
optim.step() | ||
test_cls.assertEqual(losses[0], losses[1]) | ||
|
||
|
||
def check_parity_bf16_mp( | ||
test_cls, | ||
ref_model: nn.Module, | ||
ref_model_bf16: nn.Module, | ||
ref_optim: torch.optim.Optimizer, | ||
fsdp_model: nn.Module, | ||
fsdp_optim: torch.optim.Optimizer, | ||
local_inp: torch.Tensor, | ||
module_cls: Type, | ||
): | ||
for iter_idx in range(10): | ||
losses: List[torch.Tensor] = [] | ||
for model, optim in ( | ||
(ref_model_bf16, ref_optim), | ||
(fsdp_model, fsdp_optim), | ||
): | ||
optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||
losses.append(model(local_inp).sum()) | ||
losses[-1].backward() | ||
if model is ref_model_bf16: | ||
for param_bf16, param_fp32 in zip( | ||
ref_model_bf16.parameters(), ref_model.parameters() | ||
): | ||
dist.all_reduce(param_bf16.grad) | ||
param_bf16.grad.div_(dist.get_world_size()) | ||
param_fp32.grad = param_bf16.grad.float() | ||
param_bf16.grad = None | ||
if module_cls is Float8Linear: | ||
sync_float8_amax_and_scale_history(model) | ||
optim.step() | ||
for param_fp32, param_bf16 in zip( | ||
ref_model.parameters(), ref_model_bf16.parameters() | ||
): | ||
param_bf16.detach().copy_(param_fp32) | ||
test_cls.assertEqual(losses[0], losses[1]) | ||
|
||
|
||
@contextlib.contextmanager | ||
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool): | ||
prev = config.enable_fsdp_fp8_all_gather | ||
dist.barrier() | ||
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather | ||
try: | ||
yield | ||
finally: | ||
dist.barrier() | ||
config.enable_fsdp_fp8_all_gather = prev |
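For context, a hypothetical multi-GPU test driving the helpers above might look like the sketch below. It is not part of this diff: the model construction, sizes, and swap-helper usage are illustrative assumptions, the import paths follow the float8_experimental layout of this era, and an initialized NCCL process group per rank is assumed. The real unit tests added in this commit also cover memory usage.

# Hypothetical usage sketch of the helpers defined above.
import copy

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def run_parity_test(test_cls):
    torch.manual_seed(42)
    # Construct the float8 modules inside the context manager so the fp8-aware
    # weight subclass is installed while the flag is enabled.
    with set_enable_fsdp_fp8_all_gather(True):
        ref_model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(3)]).cuda()
        swap_linear_with_float8_linear(ref_model, Float8DynamicLinear)
        fsdp_model = copy.deepcopy(ref_model)
        for module in fsdp_model:
            fully_shard(module)  # FSDP2 per-layer sharding
        fully_shard(fsdp_model)

        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
        local_inp = torch.randn(8, 16, device="cuda")

        check_parity_no_mp(
            test_cls,
            ref_model,
            ref_optim,
            fsdp_model,
            fsdp_optim,
            local_inp,
            Float8DynamicLinear,
        )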