diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py
index 570d1752a..3907408bd 100644
--- a/benchmarks/benchmark_optimizer.py
+++ b/benchmarks/benchmark_optimizer.py
@@ -20,9 +20,9 @@
 torch_version = torch.__version__.split("+")[0]
 
 if version.parse(torch_version) >= version.parse("2.3.0"):
-    from torch.amp import GradScaler
+    from torch.amp import GradScaler, autocast
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.cuda.amp import GradScaler, autocast
 
 
 @dataclass(frozen=True)
@@ -115,7 +115,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
         batch = torch.randint(0, len(X_train), (batch_size,))
 
-        with torch.amp.autocast() if args.use_amp else nullcontext():
+        with autocast() if args.use_amp else nullcontext():
             loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
             grad_scaler.scale(loss).backward()
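
For context, below is a minimal, self-contained sketch of the version-gated AMP pattern this diff introduces: importing `GradScaler`/`autocast` from `torch.amp` on PyTorch >= 2.3 (where the `torch.cuda.amp` aliases are deprecated) and from `torch.cuda.amp` otherwise. The benchmark's surrounding code is not shown in the hunks, so `model`, `device`, `use_amp`, and the optimizer here are hypothetical stand-ins for the script's own `args.use_amp`, `args.device`, etc. One caveat worth flagging: `torch.amp.autocast` takes a required `device_type` argument, unlike the legacy `torch.cuda.amp.autocast`, so the sketch passes it explicitly on the new-API path rather than calling `autocast()` bare.

```python
# Sketch of the version-gated AMP setup; stand-in names (model, device,
# use_amp, opt) are illustrative, not from the original benchmark.
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from packaging import version

torch_version = torch.__version__.split("+")[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
    # PyTorch 2.3+ exposes device-generic AMP utilities under torch.amp;
    # the torch.cuda.amp equivalents are deprecated aliases.
    from torch.amp import GradScaler, autocast
else:
    from torch.cuda.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()  # stand-in for args.use_amp
device = torch.device("cuda" if use_amp else "cpu")  # stand-in for args.device

model = torch.nn.Linear(16, 4).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
grad_scaler = GradScaler(enabled=use_amp)

X = torch.randn(64, 16, device=device)
y = torch.randint(0, 4, (64,), device=device)

# torch.amp.autocast requires device_type; the legacy torch.cuda.amp.autocast
# does not, so the call site is gated on the same version check.
if version.parse(torch_version) >= version.parse("2.3.0"):
    amp_ctx = autocast(device_type=device.type) if use_amp else nullcontext()
else:
    amp_ctx = autocast() if use_amp else nullcontext()

with amp_ctx:
    loss = F.cross_entropy(model(X), y)

# Scale the loss before backward so fp16 gradients don't underflow,
# then let the scaler unscale, step, and adjust its scale factor.
grad_scaler.scale(loss).backward()
grad_scaler.step(opt)
grad_scaler.update()
```

With `enabled=False`, `GradScaler` becomes a pass-through, so the same `scale`/`step`/`update` calls work on the non-AMP path without branching, which is presumably why the benchmark keeps a single `grad_scaler` code path and only toggles the `autocast` context.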