MrVI slowdown due to JAX compilation update #3179
@justjhong let me know if you want me to upper-bound jax for now, and to which version.
Hi @ori-kron-wis, thanks for checking. I spent some time this morning trying to debug it but was not able to find a solution.
For now, pinning jax<0.4.36. This is potentially related to jax-ml/jax#26162; check again once that is addressed. Leaving this issue open, since pinning only circumvents the problem and might create issues in the near future.
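A minimal sketch of how a project could check at runtime whether the installed jax falls under the temporary `jax<0.4.36` pin. The helper names (`release_tuple`, `satisfies_pin`, `installed_jax_satisfies_pin`) are hypothetical and not part of scvi-tools; the version parsing is a simplification that keeps only the leading numeric release components, which matches the pin for ordinary releases and for release candidates of 0.4.36.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Exclusive upper bound taken from the temporary pin discussed in this issue (jax<0.4.36).
PINNED_UPPER_BOUND = (0, 4, 36)


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Return the leading numeric release components, e.g. '0.4.35' -> (0, 4, 35).

    Simplification: suffixes such as 'rc1', '.dev1' or '+local' are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def satisfies_pin(ver: str) -> bool:
    """True if `ver` is strictly below the pinned upper bound."""
    return release_tuple(ver) < PINNED_UPPER_BOUND


def installed_jax_satisfies_pin() -> Optional[bool]:
    """Check the installed jax distribution; None if jax is not installed."""
    try:
        return satisfies_pin(version("jax"))
    except PackageNotFoundError:
        return None
```

For example, `satisfies_pin("0.4.35")` is `True`, while `satisfies_pin("0.4.36")` and `satisfies_pin("0.5.0")` are `False`.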
With recent updates to JAX, MrVI trains significantly slower than before. We suspect this is due to the new ahead-of-time (AOT) compilation strategy (https://jax.readthedocs.io/en/latest/aot.html).
To reproduce: run any basic MrVI training on a fresh install. Reproduced by @PierreBoyeau and myself.
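One way to test the compilation hypothesis is to time the first call of a step function (which includes any tracing and compilation) separately from later calls: if the steady-state time is also much slower than before, compilation is happening repeatedly or the runtime itself regressed. The sketch below is a hypothetical, library-agnostic helper (`time_first_vs_steady` is not part of MrVI or JAX). Because JAX dispatches work asynchronously, the `sync` hook should block until results are ready, e.g. `sync=jax.block_until_ready` when timing a jitted JAX function.

```python
import time
from statistics import median
from typing import Any, Callable, Dict


def time_first_vs_steady(
    fn: Callable[[], Any],
    n_steady: int = 5,
    sync: Callable[[Any], Any] = lambda out: out,
) -> Dict[str, float]:
    """Time the first call of `fn` separately from `n_steady` later calls.

    `sync` is applied to each result before the clock stops; for JAX outputs
    pass a function that blocks until the computation has finished.
    """
    if n_steady < 1:
        raise ValueError("n_steady must be at least 1")

    # First call: includes any one-off tracing/compilation cost.
    start = time.perf_counter()
    sync(fn())
    first = time.perf_counter() - start

    # Later calls: should reuse the compiled function if nothing retraces.
    steady = []
    for _ in range(n_steady):
        start = time.perf_counter()
        sync(fn())
        steady.append(time.perf_counter() - start)

    return {"first_call_s": first, "steady_median_s": median(steady)}
```

Comparing these two numbers across an older and a newer jax version would indicate whether the slowdown is a one-off compilation cost or a per-step cost.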