From d31ca2cd7e3f979a3a8955ccbe5ac81fe9fb375d Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Sat, 19 Oct 2024 09:46:20 +0200
Subject: [PATCH] Use autocast depending on the version

---
 benchmarks/benchmark_optimizer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py
index 570d1752a..3907408bd 100644
--- a/benchmarks/benchmark_optimizer.py
+++ b/benchmarks/benchmark_optimizer.py
@@ -20,9 +20,9 @@
 torch_version = torch.__version__.split("+")[0]
 
 if version.parse(torch_version) >= version.parse("2.3.0"):
-    from torch.amp import GradScaler
+    from torch.amp import GradScaler, autocast
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.cuda.amp import GradScaler, autocast
 
 
 @dataclass(frozen=True)
@@ -115,7 +115,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
             batch = torch.randint(0, len(X_train), (batch_size,))
 
-            with torch.amp.autocast() if args.use_amp else nullcontext():
+            with autocast() if args.use_amp else nullcontext():
                 loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
                 grad_scaler.scale(loss).backward()
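
For context, the gate above exists because PyTorch 2.3 promoted GradScaler (and the device-agnostic autocast) to the unified torch.amp namespace, while older releases only ship them under torch.cuda.amp. Below is a minimal standalone sketch of the same version check; the model, optimizer, and use_amp flag are illustrative stand-ins for the benchmark's own objects, and the functools.partial binding is an assumption added here (not part of the patch), since torch.amp.autocast takes a required device_type argument that the legacy torch.cuda.amp.autocast does not accept:

    import torch
    import torch.nn.functional as F
    from contextlib import nullcontext
    from functools import partial
    from packaging import version

    # Drop any local build suffix ("2.4.0+cu121" -> "2.4.0") before comparing.
    torch_version = torch.__version__.split("+")[0]

    if version.parse(torch_version) >= version.parse("2.3.0"):
        # PyTorch >= 2.3: GradScaler and autocast live in the unified torch.amp.
        from torch.amp import GradScaler, autocast

        # torch.amp.autocast requires a device_type argument; bind it so both
        # branches expose a zero-argument autocast() (an illustrative choice,
        # not taken from the patch itself).
        autocast = partial(autocast, device_type="cuda")
    else:
        # Older releases only provide the CUDA-specific variants.
        from torch.cuda.amp import GradScaler, autocast

    use_amp = torch.cuda.is_available()  # stand-in for the benchmark's args.use_amp
    device = "cuda" if use_amp else "cpu"
    model = torch.nn.Linear(8, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    grad_scaler = GradScaler(enabled=use_amp)  # no-op passthrough when disabled

    inputs = torch.randn(16, 8, device=device)
    targets = torch.randint(0, 2, (16,), device=device)

    # Mirrors the patched benchmark loop: autocast under AMP, plain context otherwise.
    with autocast() if use_amp else nullcontext():
        loss = F.cross_entropy(model(inputs), targets)

    grad_scaler.scale(loss).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()

Importing autocast once at module level, as the patch does, keeps the training loop identical on both branches instead of hard-coding the torch.amp.autocast() spelling inside it.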