I am a graduate student using MJX for robot locomotion. With the recent release of MuJoCo Playground, I am interested in using MJX instead of Isaac Gym, but a big problem I run into is that I cannot effectively debug or iterate on small code changes.
It takes more than 5 minutes to trace and compile my jitted reset and step functions, which are functionally very similar to those found here. I have already tried all the typical advice for reducing JAX compilation time (e.g. using JAX control flow instead of native Python control flow), but the main bottleneck is inside the mjx.forward and mjx.step functions.
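To check whether the time goes into Python tracing or into XLA compilation, JAX's ahead-of-time API lets you time the two stages separately. This is a minimal sketch: `toy_step` is a hypothetical stand-in for a jitted environment step (in practice you would lower a function that calls `mjx.step`), and it uses `jax.lax.fori_loop` as an example of JAX control flow.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(state, action):
    # Hypothetical stand-in for an environment step; replace with a
    # function that calls mjx.step(model, data) to profile real code.
    def body(_, s):
        return s + 0.01 * jnp.tanh(s + action)

    # JAX control flow: traced once instead of unrolled 10 times.
    return jax.lax.fori_loop(0, 10, body, state)


state = jnp.zeros(8)
action = jnp.ones(8)

t0 = time.perf_counter()
lowered = jax.jit(toy_step).lower(state, action)  # trace + lower to StableHLO
t1 = time.perf_counter()
compiled = lowered.compile()                       # XLA compilation
t2 = time.perf_counter()

print(f"trace+lower: {t1 - t0:.3f}s, compile: {t2 - t1:.3f}s")
new_state = compiled(state, action)  # call the compiled executable directly
```

If most of the time is in `lower`, the cost is tracing (Python-side), which on-disk compilation caches do not remove.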
#1273 touches on this issue, but it is somewhat old, and its advice to "use Mujoco for development and MJX for actual training" is not very satisfactory. One interesting point from that issue was:
In our roadmap are some tools for AOT'ing JIT / caching JIT output to disk, so that you can rapidly develop in MJX without waiting for JIT every time you restart your process... no ETA on that work though, but that will be a big help for sure.
Is this still on the roadmap?
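Independently of any MJX-specific tooling, recent JAX releases can already persist compiled XLA executables to disk and reuse them after a process restart. This is a minimal sketch assuming a recent JAX version; the cache path and the `double` function are placeholders. Note that only the XLA compilation step is cached: tracing and lowering still run on every process start, so this helps most when compilation, not tracing, dominates.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Configure the cache before the first jax.jit compilation in the process.
cache_dir = os.path.join(tempfile.gettempdir(), "jax_cache")
os.makedirs(cache_dir, exist_ok=True)
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Persist every compilation, regardless of how long it took
# (by default, only sufficiently slow compilations are written to disk).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)


@jax.jit
def double(x):
    # Placeholder for a jitted reset/step function.
    return 2.0 * x


y = double(jnp.arange(3.0))
```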
Feature Suggestion
How feasible would it be to have a "debug mode" (i.e., using NumPy and running on the CPU under the hood) so that people can iterate quickly during development? I am thinking of something that still lets users call the MJX API (i.e., mjx.forward and mjx.step) but does not use JAX arrays under the hood.
I am not too familiar with the internals of MJX, but since jax.numpy largely mirrors the NumPy API, could we get most of the way there just by aliasing NumPy, i.e. import numpy as jp?
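A plain alias only gets part of the way, because the two libraries diverge in places that simulation code relies on. This minimal sketch shows two such divergences: JAX arrays are immutable and are updated through `.at[...].set(...)`, and default float dtypes differ (float32 in JAX unless x64 mode is enabled, float64 in NumPy). `set_first` is a hypothetical helper, not part of MJX or jumpy.

```python
import numpy as np
import jax.numpy as jnp


def set_first(xp, x, value):
    """Hypothetical helper: return a copy of x with x[0] = value.

    With a bare `import numpy as jp` alias, code written as
    `x.at[0].set(value)` fails, because NumPy arrays have no `.at`
    attribute and are normally mutated in place instead.
    """
    if xp is jnp:
        return x.at[0].set(value)  # functional update; x is unchanged
    y = np.array(x, copy=True)     # copy to mimic JAX's functional semantics
    y[0] = value
    return y


# Default float dtypes differ unless JAX's x64 mode is enabled.
print(np.zeros(3).dtype, jnp.zeros(3).dtype)

print(set_first(np, np.zeros(3), 1.0))
print(set_first(jnp, jnp.zeros(3), 1.0))
```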
Alternatives
No response
Additional context
No response
@mpiseno We hear you and we feel the same pain! We have some pretty neat improvements to JIT speed and more granular JIT caching that we will share in a few weeks' time.
I'll come back and leave a comment here when we're ready to share more.
P.S. We actually used to have a way of switching between JAX and NumPy in Brax, called jumpy.
You're welcome to play with the idea, but let me warn you: while it seems elegant at first glance, swapping NumPy for JAX is full of really gnarly gotchas. The number of errors and edge cases ultimately made it not worth it in our eyes.
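One example of such a gotcha is control flow. Inside `jax.jit`, a comparison like `x > 0` produces a tracer, and a Python `if` on it raises an error, so a NumPy/JAX switch has to dispatch branches itself. The sketch below is a hypothetical jumpy-style `cond` wrapper written for illustration; it is not Brax's actual API.

```python
import numpy as np
import jax
import jax.numpy as jnp


def cond(pred, true_fn, false_fn, operand):
    """Hypothetical jumpy-style branch (not Brax's actual API).

    JAX values (including tracers under jax.jit) must go through
    jax.lax.cond; plain NumPy/Python values can use an ordinary `if`.
    """
    if isinstance(pred, jax.Array):  # True for concrete JAX arrays and tracers
        return jax.lax.cond(pred, true_fn, false_fn, operand)
    return true_fn(operand) if pred else false_fn(operand)


@jax.jit
def f_jax(x):
    return cond(x > 0, jnp.sin, jnp.cos, x)


def f_np(x):
    return cond(x > 0, np.sin, np.cos, x)


print(f_jax(1.0), f_np(np.float64(1.0)))
```

The wrapper dispatches on the runtime type of `pred`, so every call site must also pass branch functions from the matching library; multiplied across a whole physics pipeline, this is the kind of edge case the comment above describes.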