This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

add test to fsdp
drisspg committed Jan 18, 2024
1 parent 177173a commit 2ffcbe9
Showing 4 changed files with 91 additions and 28 deletions.
2 changes: 1 addition & 1 deletion float8_experimental/float8_ops.py
@@ -189,7 +189,7 @@ def forward(
emulate=emulate,
)
if recompute_float8_weight:
# This should be set to True when using traditional fsdp to avoid saving
# This should be set to True when using traditional fsdp to avoid
# saving the unsharded weight for backwards
ctx.save_for_backward(
x_fp8, original_weight, weight_scale, weight_amax_buffer
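
The comment above captures the motivation for the new flag: under traditional FSDP, keeping the casted (unsharded) weight alive between forward and backward is expensive, so the original weight is saved instead and the cast is redone in backward. Below is a minimal, self-contained sketch of that pattern, using an emulated float8 round-trip as a stand-in for the library's scaled cast; the `RecomputedCastLinear` name and the emulation helper are illustrative, not part of float8_experimental.

```python
import torch


def emulated_fp8_cast(w: torch.Tensor) -> torch.Tensor:
    # Round-trip through float8 as a stand-in for the real scaled fp8 cast.
    return w.to(torch.float8_e4m3fn).to(w.dtype)


class RecomputedCastLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        w_cast = emulated_fp8_cast(weight)
        # Save the original weight rather than w_cast; this mirrors the
        # `ctx.save_for_backward(x_fp8, original_weight, ...)` branch above.
        ctx.save_for_backward(x, weight)
        return x @ w_cast.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        w_cast = emulated_fp8_cast(weight)  # recomputed here, never stored
        grad_x = grad_out @ w_cast          # (B, N) @ (N, K) -> (B, K)
        grad_w = grad_out.t() @ x           # (N, B) @ (B, K) -> (N, K)
        return grad_x, grad_w
```

Calling `RecomputedCastLinear.apply(x, weight)` behaves like a plain linear matmul in forward, while only the high-precision weight (which FSDP keeps sharded) is held for backward.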
55 changes: 45 additions & 10 deletions test/test_fsdp.py
@@ -61,15 +61,22 @@ def cleanup():
dist.destroy_process_group()


def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
def get_model(
K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False
):
m = nn.Sequential(
nn.Linear(K, N, dtype=base_dtype),
nn.ReLU(),
nn.Linear(N, N, dtype=base_dtype),
nn.ReLU(),
)
if is_fp8:
swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
swap_linear_with_float8_linear(
m,
Float8Linear,
emulate=emulate,
recompute_weight_cast=recompute_weight_cast,
)
return m


@@ -81,10 +88,15 @@ def fsdp_main(rank, world_size, args):

# TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
# We can investigate and fix it later.
is_fp8, emulate, base_dtype, compile, fullgraph = args
model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
rank
)
is_fp8, emulate, base_dtype, compile, fullgraph, recompute_weight_cast = args
model = get_model(
K,
N,
is_fp8=is_fp8,
emulate=emulate,
base_dtype=base_dtype,
recompute_weight_cast=recompute_weight_cast,
).to(rank)
model.load_state_dict(torch.load(sd_in_fname))
# To compile FSDP, we need use_orig_params to True
model = FSDP(model, use_orig_params=True)
@@ -148,7 +160,13 @@ def forward_backward(model):
cleanup()


def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
def run(
mode: str,
is_fp8: bool,
compile_fsdp: bool = False,
fullgraph: bool = False,
recompute_weight_cast: bool = False,
):
print(f"Mode: {mode}".center(100, "-"))
base_dtype = torch.bfloat16
if not os.path.exists(data_dir):
@@ -169,15 +187,25 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F
# generate reference input
ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
model = get_model(
K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
K,
N,
is_fp8=is_fp8,
emulate=emulate,
base_dtype=base_dtype,
recompute_weight_cast=recompute_weight_cast,
).cuda()
torch.save(ref_input, input_fname)
torch.save(model.state_dict(), sd_in_fname)

elif mode == "single_gpu":
ref_input = torch.load(input_fname).to(base_dtype)
model = get_model(
K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
K,
N,
is_fp8=is_fp8,
emulate=emulate,
base_dtype=base_dtype,
recompute_weight_cast=recompute_weight_cast,
).cuda()
model.load_state_dict(torch.load(sd_in_fname))
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
@@ -199,7 +227,14 @@ def forward_backward():
elif mode == "fsdp":
WORLD_SIZE = torch.cuda.device_count()
# We only compile for fsdp, and compare the numerics with single-gpu no-compile
args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
args = (
is_fp8,
emulate,
base_dtype,
compile_fsdp,
fullgraph,
recompute_weight_cast,
)
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

elif mode == "analyze":
35 changes: 26 additions & 9 deletions test/test_fsdp.sh
@@ -4,14 +4,18 @@
set -e

launch() {
echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
echo "Launching test with the following configuration:"
echo "IS_FP8: $IS_FP8"
echo "compile_fsdp: $COMPILE"
echo "fullgraph: $FULLGRAPH"
echo "recompute_weight_cast: $RECOMPUTE"

# generate the test data
python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
echo "Success: ✅"

# generate single GPU model output and updated state dict
python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
echo "Success: ✅"

# generate FSDP model output and updated state dict
@@ -20,19 +24,32 @@ launch() {
# the NCCL_NET setting is to work around transient issues on a
# specific host (`devgpu001.nha2`)
NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \
--mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
--mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE

# compare the outputs and state dicts and verify equivalence
python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH --recompute_weight_cast $RECOMPUTE
echo "Success: ✅"

echo "✅ All Tests Passed ✅"
}

# IS_FP8, COMPILE, FULLGRAPH
for i in False,False,False True,False,False True,True,False
# Loop over different combinations of settings
for i in False,False,False,False \
True,False,False,False \
True,True,False,False \
True,False,False,True \
True,True,False,True
do
IFS=","; set -- $i;
IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3
# Split the string into variables
IFS=","
set -- $i

# Assign each variable to a more descriptive name
IS_FP8=$1
COMPILE=$2
FULLGRAPH=$3
RECOMPUTE=$4

# Launch the test with the current settings
launch
done
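
The launch() pipeline above shares data between its four stages (generate, single_gpu, fsdp, analyze) through files on disk, and the final stage verifies that the single-GPU and FSDP runs agree. A hedged sketch of what such an equivalence check could look like; the file names and tolerances are illustrative, not the test's actual values.

```python
import torch


def assert_equivalent(single_gpu_path: str, fsdp_path: str) -> None:
    # Load the artifacts written by the earlier stages (output tensors or
    # state dicts) and assert they match within tolerance.
    ref = torch.load(single_gpu_path)
    out = torch.load(fsdp_path)
    torch.testing.assert_close(ref, out, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    # Hypothetical file names; the real test derives its paths from data_dir.
    assert_equivalent("single_gpu_out.pt", "fsdp_out.pt")
```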
27 changes: 19 additions & 8 deletions test/test_fsdp_compile.py
@@ -48,12 +48,16 @@ def cleanup():
dist.destroy_process_group()


def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
def get_model(
K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False
):
m = nn.Sequential(
nn.Linear(K, N, dtype=base_dtype),
nn.ReLU(),
)
swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
swap_linear_with_float8_linear(
m, Float8Linear, emulate=emulate, recompute_weight_cast=recompute_weight_cast
)
return m


@@ -63,7 +67,7 @@ def fsdp_main(rank, world_size, args):
setup(rank, world_size)
torch.cuda.set_device(rank)

(emulate,) = args
(emulate, recompute_weight_cast) = args

# composability of torch.compile + FSDP + autocast + Float8Linear
# as of 2023-12-30
@@ -81,9 +85,14 @@
# things work e2e. Note that FSDP does not support full-graph compile
# regardless of float8.

model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to(
rank
)
model = get_model(
K,
N,
is_fp8=True,
emulate=emulate,
base_dtype=torch.bfloat16,
recompute_weight_cast=recompute_weight_cast,
).to(rank)

# To compile FSDP, we need use_orig_params to True
model = FSDP(model, use_orig_params=True)
@@ -102,7 +111,8 @@ def fsdp_main(rank, world_size, args):
sync_float8_func(model)
optimizer.step()

print("done!")
if rank == 0:
print("Success: ✅")
cleanup()
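
Putting this file's pieces together, the comments above describe the combination under test: FSDP with use_orig_params=True, torch.compile without fullgraph, bf16 autocast, Float8Linear, and a per-step amax/scale sync. The sketch below is a hedged reconstruction of that composition, assuming `sync_float8_amax_and_scale_history` from float8_experimental.float8_linear_utils is the helper behind `sync_float8_func`; treat the exact names as illustrative rather than the test's literal code.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumed helper; the test aliases something like this as `sync_float8_func`.
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history


def compile_and_shard(model: nn.Module) -> nn.Module:
    # use_orig_params=True is required to compose FSDP with torch.compile;
    # full-graph compile is not expected to work, with or without float8.
    model = FSDP(model, use_orig_params=True)
    return torch.compile(model)


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer, batch: torch.Tensor) -> torch.Tensor:
    optimizer.zero_grad()
    # bf16 autocast around the forward, matching the composition described above.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = model(batch).sum()
    loss.backward()
    # Sync fp8 amax/scale history before the optimizer step so scales stay
    # consistent across ranks; this is the role sync_float8_func plays above.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    return loss
```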


@@ -119,7 +129,8 @@ def run():
emulate = True

WORLD_SIZE = torch.cuda.device_count()
args = (emulate,)
recompute_weight_cast = True
args = (emulate, recompute_weight_cast)
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


