diff --git a/gpustat/core.py b/gpustat/core.py index e0a2804..2fcbf12 100644 --- a/gpustat/core.py +++ b/gpustat/core.py @@ -31,14 +31,14 @@ from gpustat import util -if util.hasNvidia(): - from gpustat import nvml - from gpustat.nvml import nvml as N - from gpustat.nvml import check_driver_nvml_version -else: +if util.hasAMD(): from gpustat import rocml as nvml from gpustat import rocml as N from gpustat.rocml import check_driver_nvml_version +else: + from gpustat import nvml + from gpustat.nvml import nvml as N + from gpustat.nvml import check_driver_nvml_version NOT_SUPPORTED = 'Not Supported' MB = 1024 * 1024 diff --git a/gpustat/util.py b/gpustat/util.py index d865654..e6b0067 100644 --- a/gpustat/util.py +++ b/gpustat/util.py @@ -104,9 +104,9 @@ def report_summary(self, concise=True): self._write('') -def hasNvidia(): +def hasAMD(): try: - subprocess.check_output('nvidia-smi') + subprocess.check_output('rocm-smi') return True except Exception: return False