This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Remove fairscale tp/sp (#241)
Summary:
Now that we have DTensor support, remove the existing fairscale implementations of TP/SP.

Pull Request resolved: #241

Reviewed By: vkuzo

Differential Revision: D55164494

Pulled By: drisspg

fbshipit-source-id: 90ce3f23f370a8bf9b72256d108f592b9b278379
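
Not part of the commit itself: below is a minimal sketch of the DTensor-style tensor parallelism that the summary points to as the replacement for the removed fairscale TP/SP path. The ToyMLP module, mesh shape, and sizes are hypothetical; it relies on torch.distributed.tensor.parallel (parallelize_module, ColwiseParallel, RowwiseParallel) from torch >= 2.2 and is meant to be launched with torchrun, one process per GPU.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    # Hypothetical two-layer block standing in for the modules that used to be
    # wrapped by the fairscale-based tp_linear helpers.
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def main():
    # 1-D mesh over all visible GPUs; assumes the script is launched via
    # torchrun so the default process group can be initialized.
    mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
    model = ToyMLP().cuda()
    # Shard w1 column-wise and w2 row-wise so the pair forms the usual
    # Megatron-style TP block, expressed with DTensor instead of fairscale.
    parallelize_module(
        model,
        mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )
    out = model(torch.randn(8, 1024, device="cuda"))
    print(out.shape)


if __name__ == "__main__":
    main()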
drisspg authored and facebook-github-bot committed Mar 21, 2024
1 parent 88e9e50 commit 14da04f
Showing 8 changed files with 18 additions and 393 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -3,7 +3,7 @@ float8_experimental/__pycache__/*
 finetune/__pycache__/*
 test/__pycache__/*
 torch_compile_debug/*
-tmp/*
+test/tmp/*
 benchmarks/data/*
 
 # Distribution / packaging
213 changes: 0 additions & 213 deletions float8_experimental/tp_linear.py

This file was deleted.

5 changes: 2 additions & 3 deletions pyproject.toml
@@ -14,13 +14,12 @@ classifiers = [
 ]
 
 dependencies = [
-"torch >= 2.1",
-"fairscale==0.4.13"
+"torch >= 2.2",
 ]
 
 [project.optional-dependencies]
 test = [
-"transformers==4.32.0",
+"transformers==4.38.2",
 "pandas >= 2.0",
 "tqdm==4.66.2",
 "fire==0.5.0",
1 change: 0 additions & 1 deletion test/test_everything.sh
@@ -8,7 +8,6 @@ pytest test/test_sam.py
 pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
-./test/test_tp.sh
 ./test/test_dtensor.sh
 
 echo "all tests successful"
2 changes: 1 addition & 1 deletion test/test_fsdp.py
@@ -36,7 +36,7 @@
 torch.manual_seed(0)
 
 # assumes user is running the script from /data/users/{user}/float8_experimental
-data_dir = os.path.join(os.getcwd(), "tmp")
+data_dir = os.path.join(os.path.dirname(__file__), "tmp")
 input_fname = os.path.join(data_dir, "input.pt")
 sd_in_fname = os.path.join(data_dir, "sd_in.pt")
 sd_out_single_gpu_fname = os.path.join(data_dir, "sd_out_single_gpu.pt")
18 changes: 14 additions & 4 deletions test/test_sam.py
@@ -12,6 +12,7 @@
 import pytest
 
 import torch
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
@@ -20,20 +21,24 @@
 from float8_experimental.float8_utils import compute_error
 from transformers import SamModel
 
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
 torch.manual_seed(0)
 
 
 class TestFloat8SAMIntegrationTest:
     @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
-    def test_encoder_fw_bw(self, data_dtype):
+    @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
+    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    def test_encoder_fw_bw(self, data_dtype, linear_type):
         model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
         # print(model)
 
         # for now just test the encoder to simplify things
         encoder_ref = model.vision_encoder
         encoder_fp8 = copy.deepcopy(encoder_ref)
-        swap_linear_with_float8_linear(encoder_fp8, Float8Linear, emulate=False)
+        swap_linear_with_float8_linear(encoder_fp8, linear_type, emulate=False)
 
         # an image
         # Note: bsz==4 or a larger power of 2 for this model is needed to
@@ -55,7 +60,12 @@ def test_encoder_fw_bw(self, data_dtype):
         ref_name_to_grad = {
             name: param.grad for name, param in encoder_ref.named_parameters()
         }
-        sqnr_threshold = 1.0 if data_dtype == torch.float16 else -4
+
+        # Delayed scaling has less performant numerics
+        fudge_factor = 7.0 if linear_type == Float8Linear else 1.0
+        sqnr_threshold = -1.0 if data_dtype == torch.float16 else -4
+        sqnr_threshold *= fudge_factor
+
         for name, param in encoder_fp8.named_parameters():
             ref_grad = ref_name_to_grad[name]
             cur_grad = param.grad
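
Not part of the diff: a minimal sketch of the swap-and-compare pattern the updated test parametrizes, using the same swap_linear_with_float8_linear, Float8DynamicLinear, and compute_error entry points seen above. The toy model, shapes, and the printed SQNR readout are illustrative; like the test's non-emulated path, it assumes an H100-class GPU.

import copy

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_utils import compute_error

torch.manual_seed(0)

# Assumes a CUDA device with float8 support (H100-class), matching the
# test's skip condition.
model_ref = (
    nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256))
    .to(torch.bfloat16)
    .cuda()
)
model_fp8 = copy.deepcopy(model_ref)
# Swap every nn.Linear for a dynamically scaled float8 linear, in place,
# mirroring what the test does to the SAM vision encoder.
swap_linear_with_float8_linear(model_fp8, Float8DynamicLinear, emulate=False)

x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
out_ref = model_ref(x)
out_fp8 = model_fp8(x)

# compute_error returns an SQNR in dB; higher means the float8 model tracks
# the bfloat16 reference more closely.
print(f"forward SQNR (dB): {compute_error(out_ref, out_fp8).item():.2f}")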