From ec8b46cda737cb72c0769eba42341edf50111e22 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Thu, 18 Jul 2024 10:52:13 -0700
Subject: [PATCH] fixes to matmul and linear benchmarks (#320)

Summary:
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/320

for matmul benchmarks, unbreaks them - we need the scales to be fp32, not integers

for linear benchmarks, aligns default settings to current best supported path (compile on, dynamic scaling)

Reviewed By: awgu

Differential Revision: D59877198

fbshipit-source-id: 092daaffeb0096f9fbd12ca407701bc3aa80c97c
---
 benchmarks/bench_linear_float8.py | 12 ++++++------
 benchmarks/bench_matmul.py        |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py
index 5f8e4f9..967de57 100644
--- a/benchmarks/bench_linear_float8.py
+++ b/benchmarks/bench_linear_float8.py
@@ -91,13 +91,13 @@ def float8_pct_top_peak(self):
 
 def main(
     sweep_path: Optional[Path] = None,
-    compile: bool = False,
+    compile: bool = True,
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
-    scaling_type_x: str = "delayed",
-    scaling_type_w: str = "delayed",
-    scaling_type_dL_dY: str = "delayed",
+    scaling_type_x: str = "dynamic",
+    scaling_type_w: str = "dynamic",
+    scaling_type_dL_dY: str = "dynamic",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
@@ -274,7 +274,7 @@ def wrapper(*args, **kwargs):
 def invoke_main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output_path", type=str, required=False)
-    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--disable_compile", action="store_true")
     parser.add_argument("-n", "--n_limit", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
@@ -292,7 +292,7 @@ def invoke_main() -> None:
         kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
     main(
         output_path,
-        args.compile,
+        not args.disable_compile,
         args.n_limit,
         args.fast_accum_filter,
         args.shape_name_filter,
diff --git a/benchmarks/bench_matmul.py b/benchmarks/bench_matmul.py
index 967267d..6220670 100644
--- a/benchmarks/bench_matmul.py
+++ b/benchmarks/bench_matmul.py
@@ -101,8 +101,8 @@ def run(n_limit: Optional[int] = None):
     B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
 
     def do_matmul(A, B):
-        scale_a = torch.tensor([1], device=device)
-        scale_b = torch.tensor([1], device=device)
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)
         return torch._scaled_mm(
             A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
         )
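
Note on the matmul fix: torch._scaled_mm expects its per-tensor scales to be
float32 tensors; passing integer tensors (torch.tensor([1])) is what broke the
benchmark. Below is a minimal sketch of the patched call, not part of the patch
itself; it assumes a CUDA device with float8 support and a PyTorch build whose
torch._scaled_mm signature matches the call in bench_matmul.py above, and the
shapes and dtypes are illustrative only.

    # Minimal sketch (assumptions: CUDA device with float8 support, and the
    # positional-scale torch._scaled_mm signature used in bench_matmul.py).
    import torch

    device = "cuda"
    M, K, N = 1024, 1024, 1024  # illustrative shapes, not from the benchmark sweep
    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16

    A = torch.zeros(M, K, device=device, dtype=d1)
    # B is made column-major for the fp8 gemm, hence .t().contiguous().t()
    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()

    scale_a = torch.tensor([1.0], device=device)  # fp32 scale, not torch.tensor([1])
    scale_b = torch.tensor([1.0], device=device)

    out = torch._scaled_mm(
        A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
    )

The linear benchmark change is the CLI-level counterpart of the new defaults:
compile is now on by default, and --disable_compile opts out.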