Skip to content

Commit

Permalink
fix grad scalar import on torch > 2.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jun 12, 2024
1 parent d20e810 commit 4c52dbf
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions hivemind/optim/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
from packaging import version

torch_version = torch.__version__.split("+")[0]

if version.parse(torch_version) >= version.parse("2.3.0"):
from torch.amp import GradScaler as TorchGradScaler
from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
else:
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state

from torch.optim import Optimizer as TorchOptimizer

import hivemind
Expand Down

0 comments on commit 4c52dbf

Please sign in to comment.