
Ability to Iterate Quickly in MJX #2441

Open
mpiseno opened this issue Feb 20, 2025 · 1 comment
Labels
enhancement New feature or request

Comments


mpiseno commented Feb 20, 2025

The feature, motivation and pitch

I am a graduate student using MJX for robot locomotion. With the recent release of MuJoCo Playground, I am interested in using MJX instead of IsaacGym, but a huge problem I run into is that I cannot effectively debug or iterate on small code changes.

It takes more than 5 minutes to trace and compile my jitted reset and step functions, which are functionally very similar to those found here. I have already tried the typical advice for reducing JAX compilation time (e.g., using JAX control flow instead of native Python control flow), but the main bottleneck is inside the mjx.forward and mjx.step functions.
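One way to see whether those minutes go to Python tracing or to XLA compilation is JAX's ahead-of-time API: `jax.jit(f).lower(...)` traces and lowers the function, and `.compile()` runs XLA. The sketch below is a minimal illustration under that assumption; `toy_step` is a hypothetical stand-in for a jitted step function (a real one would call `mjx.step`), so its timings say nothing about MJX itself.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(state, action):
    """Hypothetical stand-in for an environment step (a real one would call mjx.step)."""

    def substep(s, _):
        # One cheap "physics" substep; lax.scan keeps the traced graph small.
        return s + 0.01 * jnp.tanh(s + action), None

    state, _ = jax.lax.scan(substep, state, None, length=10)
    return state


state = jnp.zeros(8)
action = jnp.ones(8)

t0 = time.perf_counter()
lowered = jax.jit(toy_step).lower(state, action)  # Python tracing + lowering
t1 = time.perf_counter()
compiled = lowered.compile()  # XLA compilation
t2 = time.perf_counter()
new_state = compiled(state, action).block_until_ready()  # execution only
t3 = time.perf_counter()

print(f"trace+lower {t1 - t0:.3f}s | compile {t2 - t1:.3f}s | run {t3 - t2:.3f}s")
```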

#1273 touches on this issue, but it is somewhat old, and its advice of "use MuJoCo for development and MJX for actual training" is not very satisfactory. One interesting point from that issue was:

In our roadmap are some tools for AOT'ing JIT / caching JIT output to disk, so that you can rapidly develop in MJX without waiting for JIT every time you restart your process... no ETA on that work though, but that will be a big help for sure.

Is this still on the roadmap?
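Independently of MJX's own roadmap, JAX already ships a persistent compilation cache that writes compiled executables to disk and reuses them after a process restart. A minimal sketch, assuming a fresh process (the cache should be configured before the first compilation) and an arbitrary cache directory `/tmp/jax_cache`. Note that it only skips the XLA compilation step: Python tracing of functions such as `mjx.step` still happens on every restart.

```python
import jax
import jax.numpy as jnp

# Persist compiled executables to disk; the path is an arbitrary choice for this sketch.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# By default only large / slow-to-compile entries are cached; lower both
# thresholds so that even this tiny example is written to the cache.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def step(x):
    # Trivial jitted function; on a second run of the script its XLA
    # executable can be loaded from /tmp/jax_cache instead of recompiled.
    return jnp.sin(x) * 2.0 + 1.0


y = step(jnp.arange(4.0))
print(y)
```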

Feature Suggestion
How feasible would it be to have a "debug mode" (i.e., using NumPy and running everything on the CPU) under the hood, so that people can iterate easily during development? I am thinking of something that still lets users call the MJX API (i.e., mjx.forward and mjx.step) but does not use JAX arrays under the hood.
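JAX itself offers something in this direction: inside `jax.disable_jit()`, jitted functions run op by op on concrete arrays, so `print`, `breakpoint()` and Python `if` statements on values work, and there is no whole-function compilation step. A minimal sketch with a hypothetical helper `clip_to_unit_norm`; whether `mjx.step` runs at a usable speed in this mode is an untested assumption, and the values are still JAX arrays rather than NumPy ones.

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_to_unit_norm(x):
    """Hypothetical helper: rescale x so its Euclidean norm is at most 1."""
    n = jnp.linalg.norm(x)
    # A Python `if` on a traced value raises under jit; it is fine when run eagerly.
    if n > 1.0:
        x = x / n
    return x


x = jnp.array([3.0, 4.0])
with jax.disable_jit():
    y = clip_to_unit_norm(x)  # executes eagerly; breakpoint() would work inside

print(y)
```

Setting `jax.config.update("jax_disable_jit", True)` applies the same switch globally instead of within a `with` block.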

I am not too familiar with the internals of MJX, but since jax.numpy largely mirrors the NumPy API, could we get most of the way there by simply aliasing NumPy, i.e., import numpy as jp?
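The aliasing idea can be sketched as code written against an array module passed in as a parameter, so that the same function runs under NumPy or jax.numpy. `pd_torque` and its gains are hypothetical, and this only works while the code stays inside the subset the two libraries share (no in-place writes, no `.at[]` updates, no `jax.lax` control flow).

```python
import numpy as np
import jax.numpy as jnp


def pd_torque(xp, q, qd, q_target, kp=50.0, kd=1.0, limit=10.0):
    """Hypothetical PD controller; `xp` is the array module (numpy or jax.numpy)."""
    tau = kp * (q_target - q) - kd * qd
    return xp.clip(tau, -limit, limit)


q = [0.1, -0.2, 0.3]
qd = [0.0, 0.5, -1.0]
q_target = [0.0, 0.0, 0.0]

# Same source, two backends: eager NumPy on the CPU, or jax.numpy (jit-able).
tau_np = pd_torque(np, np.asarray(q), np.asarray(qd), np.asarray(q_target))
tau_jax = pd_torque(jnp, jnp.asarray(q), jnp.asarray(qd), jnp.asarray(q_target))

print(tau_np, tau_jax)
```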

Alternatives

No response

Additional context

No response

mpiseno added the enhancement label Feb 20, 2025
erikfrey (Collaborator) commented Mar 7, 2025

@mpiseno We hear you and we feel the same pain! We have some pretty neat improvements to JIT speed and more granular JIT caching that we will share in a few weeks' time.

I'll come back and leave a comment here when we're ready to share more.

p.s. we actually used to have a way of switching between JAX and NumPy in Brax, called jumpy:

https://github.com/google/brax/blob/main/brax/v1/jumpy.py

You're welcome to play with the idea, but let me warn you that while at first glance it seems elegant, the idea of swapping numpy for JAX is full of really gnarly gotchas. The number of errors and edge cases ultimately made it not worth it in our eyes.
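A few general JAX behaviours show how a plain NumPy-for-JAX swap can leak: JAX arrays are immutable, JAX defaults to 32-bit floats unless x64 mode is enabled, and Python control flow on array values fails under `jax.jit`. The sketch below illustrates these three; they are chosen for illustration and are not necessarily the specific problems that led the Brax team to drop jumpy. `positive_sum` is a hypothetical helper.

```python
import numpy as np
import jax
import jax.numpy as jnp

gotchas = {}

# 1. Mutation: NumPy allows in-place writes; JAX arrays are immutable.
a = np.zeros(3)
a[0] = 1.0
b = jnp.zeros(3)
try:
    b[0] = 1.0
    gotchas["item_assignment_fails"] = False
except TypeError:
    gotchas["item_assignment_fails"] = True
b = b.at[0].set(1.0)  # JAX's functional update; NumPy arrays have no `.at`

# 2. Precision: NumPy defaults to float64, JAX to float32 unless x64 is enabled.
gotchas["default_dtypes"] = (np.zeros(3).dtype, jnp.zeros(3).dtype)


# 3. Control flow: a Python `if` on an array value works eagerly, not when traced.
def positive_sum(x):
    s = x.sum()
    if s > 0:
        return s
    return 0.0 * s


positive_sum(np.array([1.0, 2.0]))  # fine with NumPy
try:
    jax.jit(positive_sum)(jnp.array([1.0, 2.0]))
    gotchas["python_if_under_jit_fails"] = False
except jax.errors.ConcretizationTypeError:
    gotchas["python_if_under_jit_fails"] = True

print(gotchas)
```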
