When I install optax, I am no longer able to use the GPU #1144
Hello @Alessandro-Castelli,
Thank you, @vroulet.
I tried creating two separate conda environments: one where I use TensorFlow to download the dataset and another where I install PennyLane, JAX, JAXlib, and Optax to train the model, but the error still occurs.
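In each environment, it can help to check which backend JAX actually selects, independently of TensorFlow. Below is a minimal sketch using `jax.default_backend()`, which returns a platform name such as `'cpu'` or `'gpu'`. The wrapper `jax_backend` is a hypothetical helper. The import is deferred so the script also runs where JAX is absent.

```python
from typing import Optional


def jax_backend() -> Optional[str]:
    """Return JAX's default backend name, or None if JAX is not installed."""
    try:
        import jax  # imported lazily so this check works even without JAX
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed in this environment")
    else:
        print(f"JAX default backend: {backend}")
        if backend == "cpu":
            print("JAX fell back to CPU; inspect the installed jaxlib build")
```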
At this point, I think the problem is optax.
Name: jax
Name: jaxlib
Name: optax
I really don't think so. Just look at the code in optax: it's quite a lightweight library, unrelated to any CUDA/GPU functionality. There could have been bumped imports, but the optax, jax, and jaxlib versions above seem fine. I cannot reproduce the error you're mentioning, as tensorflow 2.9 doesn't seem to be available to me locally, and in any case I don't have a GPU. The error clearly points to jaxlib, not optax.
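To test whether the problem lies in the jaxlib build rather than in optax, one can inspect the installed jaxlib version string. This is a minimal sketch. It assumes, based on the `pip show` output later in this thread (`0.4.23+cuda11.cudnn86`), that a CUDA-enabled jaxlib carries a `+cuda...` local version tag. That convention is inferred from that one output and may not hold for every JAX release, so a `False` result is a hint, not proof. The helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def has_cuda_local_tag(version_string: str) -> bool:
    """True if the local version segment (after '+') mentions CUDA.

    Assumption: CUDA builds look like '0.4.23+cuda11.cudnn86', while a
    plain '0.4.30' has no local tag.
    """
    _, sep, local = version_string.partition("+")
    return bool(sep) and "cuda" in local.lower()


if __name__ == "__main__":
    jaxlib_version = installed_version("jaxlib")
    if jaxlib_version is None:
        print("jaxlib is not installed")
    else:
        print(f"jaxlib {jaxlib_version}, CUDA tag: {has_cuda_local_tag(jaxlib_version)}")
```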
@vroulet I’ll explain why I think it’s optax. In my initial code, I was using jax and jaxlib 0.4.23, PennyLane, and tensorflow 2.9.0, and I didn’t have any issues installing those versions. At some point, I needed more powerful optimizers like Adam for some tasks, and that’s when I started using optax. Only from that moment did I encounter the issue:
2024-11-28 17:39:56.095832: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
cpu
Maybe optax doesn’t get along with tensorflow. Another thing: when I install optax, it updates jax and jaxlib to version 0.4.30, so maybe that’s the problem.
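To confirm that installing optax changes the jax and jaxlib versions, one could record the installed versions before and after `pip install optax` and compare them. This is a minimal sketch. The package list and the helper names `snapshot` and `changed` are illustrative and not part of any tool.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional, Tuple

PACKAGES = ("jax", "jaxlib", "optax")


def snapshot(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version (None if missing)."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def changed(
    before: Dict[str, Optional[str]], after: Dict[str, Optional[str]]
) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Return {name: (old, new)} for every package whose version differs."""
    return {
        name: (before.get(name), after.get(name))
        for name in sorted(set(before) | set(after))
        if before.get(name) != after.get(name)
    }


if __name__ == "__main__":
    print(snapshot())
```

Take one snapshot, save it (for example with `json`), run `pip install optax`, take a second snapshot in a fresh interpreter, and pass both to `changed`. If jaxlib's entry moves from a `+cuda...` version to an untagged one, the CUDA build was most likely replaced.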
I feel the pain; I've been there. Versioning between CUDA, JAX, and TensorFlow is a mess. I would suggest using separate virtualenvs for JAX-based and TF-based projects if you can.
Hello @fabianp, I tried that, but I think the real problem is version compatibility between JAX and Optax. I tried many different Optax versions, but none of them resolved my problem.
Have you tried installing optax with
Yes, but Optax has additional dependencies, so this approach doesn't seem to work for it.
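If optax is installed without letting pip resolve its dependencies, you have to install those dependencies yourself. This minimal sketch reads the requirements a distribution declares in its metadata (`importlib.metadata.requires`) and reports which ones are not installed. The name parsing is a simplified assumption: it ignores version specifiers and environment markers, and it only skips requirements gated on an extra. The helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, requires, version
from typing import Iterable, List

# Leading project name of a requirement string such as "chex>=0.1.86".
_NAME = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def requirement_names(reqs: Iterable[str]) -> List[str]:
    """Extract project names, skipping requirements that only apply to an extra."""
    names: List[str] = []
    for req in reqs:
        spec, _, marker = req.partition(";")
        if "extra" in marker:
            continue  # e.g. "sphinx; extra == 'docs'"
        match = _NAME.match(spec)
        if match:
            names.append(match.group(1))
    return names


def missing_dependencies(dist_name: str) -> List[str]:
    """Declared dependencies of dist_name that are not installed.

    Raises PackageNotFoundError if dist_name itself is not installed.
    """
    declared = requires(dist_name) or []
    missing: List[str] = []
    for name in requirement_names(declared):
        try:
            version(name)
        except PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    try:
        print("missing:", missing_dependencies("optax"))
    except PackageNotFoundError:
        print("optax is not installed")
```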
"I have jax 0.4.23. What happens is that when I install optax with the command
`pip install optax`, I get an error message saying 'An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. [CpuDevice(id=0)]'. This error only occurs after I install optax. Which version of optax is compatible with my versions of jax and tensorflow 2.9.0?"
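To find an optax release that works with a pinned jax, one approach is to compare the installed jax version against the lower bound each optax release declares for `jax`. For the installed release, `importlib.metadata.requires("optax")` shows this bound; for other releases, check the project's release metadata. This minimal comparison sketch assumes plain numeric `X.Y.Z` versions and ignores the local `+...` tag. The bound used below is a hypothetical placeholder, not optax's actual requirement. For real use, `packaging.version.Version` handles the full PEP 440 grammar.

```python
from typing import Tuple


def release_tuple(version_string: str) -> Tuple[int, ...]:
    """'0.4.23+cuda11.cudnn86' -> (0, 4, 23); drops the local '+...' tag.

    Trailing zeros are stripped so that '0.4' and '0.4.0' compare equal.
    Assumes purely numeric components (no 'rc', 'dev', etc.).
    """
    public = version_string.split("+", 1)[0]
    parts = [int(part) for part in public.split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def satisfies_minimum(installed: str, minimum: str) -> bool:
    """True if installed >= minimum, comparing release numbers only."""
    return release_tuple(installed) >= release_tuple(minimum)


if __name__ == "__main__":
    installed_jax = "0.4.23"       # from the pip show output in this thread
    hypothetical_bound = "0.4.27"  # placeholder, NOT optax's real bound
    print(satisfies_minimum(installed_jax, hypothetical_bound))
```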
Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:
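The `pip show` output above can also be checked programmatically, which helps when comparing several environments. This minimal sketch splits the text into one dict per package, starting a new entry at each `Name:` line. It assumes pip's default `Key: value` layout, and the sample is abbreviated from the output in this thread. `parse_pip_show` is a hypothetical helper.

```python
from typing import Dict, List

SAMPLE = """\
Name: jax
Version: 0.4.23
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Requires: ml-dtypes, numpy, scipy
"""


def parse_pip_show(text: str) -> List[Dict[str, str]]:
    """Split `pip show a b ...` output into one dict per package."""
    packages: List[Dict[str, str]] = []
    for line in text.splitlines():
        key, sep, value = line.partition(":")  # split on the first colon only
        if not sep:
            continue  # skip separator lines such as '---'
        key, value = key.strip(), value.strip()
        if key == "Name":
            packages.append({})
        if packages:
            packages[-1][key] = value
    return packages


if __name__ == "__main__":
    for pkg in parse_pip_show(SAMPLE):
        print(pkg["Name"], pkg.get("Version"))
```

Splitting on the first colon keeps values such as `Home-page: https://...` intact.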