Creating MJX data struct on GPU with make_data() appears to be slower than put_data() #2461

Open
2 tasks done
Balint-H opened this issue Feb 25, 2025 · 0 comments
Labels
bug (Something isn't working), MJX (Using JAX to run on GPU)


@Balint-H
Collaborator

Balint-H commented Feb 25, 2025

Intro

Hi!

I am using MJX for RL in motor control scenarios.

My setup

I use MJX (prerelease/HEAD version) on a WSL setup.

What's happening? What did you expect?

I expected initializing MJX data to be faster when calling make_data() on an MJX model than when creating the data struct with CPU MuJoCo and then calling put_data(), especially with a GPU backend. However, timing both approaches with timeit shows the opposite: make_data() takes roughly 2x as long, and the gap grows with more complex models.

When I use a CPU backend, make_data() is indeed faster as expected.
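To compare both backends in a single process instead of switching `jax_platform_name`, one option is to pin array creation to a specific device with the `jax.default_device` context manager. The following is a minimal sketch, assuming a recent JAX release where `jax.Array.devices()` exists. The helper `make_on` is hypothetical, and whether every MJX call (e.g. `load_with_make()` in the repro below) honors the default device is an assumption worth verifying.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX selected by default ("cpu", "gpu", or "tpu").
print("default backend:", jax.default_backend())
print("devices:", jax.devices())


def make_on(device, shape=(4,)):
    """Hypothetical helper: create an array with `device` as the default placement."""
    with jax.default_device(device):
        x = jnp.zeros(shape)
    return x


cpu = jax.devices("cpu")[0]
x = make_on(cpu)
# The array should report the CPU device as its placement.
print(x.devices())
```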

Steps for reproduction

Example code with `timeit` and the MuJoCo humanoid:

```python
import timeit
from etils import epath
import jax
import mujoco
from mujoco import mjx

jax.config.update('jax_platform_name', 'gpu')

path = (epath.Path(epath.resource_path('mujoco')) / (
        'mjx/test_data/humanoid/humanoid.xml')).as_posix()

def load_with_put(mjcf_file=path):
    # Build MjData with CPU MuJoCo, then copy it to the device with put_data().
    mj_model = mujoco.MjModel.from_xml_path(mjcf_file)
    mjx_model = mjx.put_model(mj_model)
    mj_data = mujoco.MjData(mj_model)
    mjx_data = mjx.put_data(mj_model, mj_data)
    return mjx_data, mjx_model

def load_with_make(mjcf_file=path):
    # Create the data struct directly from the MJX model with make_data().
    mj_model = mujoco.MjModel.from_xml_path(mjcf_file)
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.make_data(mjx_model)
    return mjx_data, mjx_model

# Total seconds for 100 calls of each loader.
print(timeit.timeit("load_with_put()", setup="from __main__ import load_with_put, load_with_make", number=100))
print(timeit.timeit("load_with_make()", setup="from __main__ import load_with_put, load_with_make", number=100))
```
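JAX dispatches device work asynchronously, so a function can return before the GPU has finished, and the first call also pays one-time tracing and initialization costs. A timing helper that warms up once and waits on the result can make the two loaders easier to compare. This is a minimal sketch, assuming `jax.block_until_ready` (which accepts any pytree, such as the `(mjx_data, mjx_model)` tuple) is available in the installed JAX version. `time_blocking` is a hypothetical helper, and the stand-in workload below is only a placeholder.

```python
import timeit

import jax
import jax.numpy as jnp


def time_blocking(fn, number=100):
    """Hypothetical helper: average seconds per call of `fn`, including device work.

    `jax.block_until_ready` waits on every array in the returned pytree,
    so the measurement covers the device-side work, not just dispatch.
    """
    # Warm-up call: keeps one-time costs (tracing, compilation,
    # backend initialization) out of the measurement.
    jax.block_until_ready(fn())
    total = timeit.timeit(lambda: jax.block_until_ready(fn()), number=number)
    return total / number


if __name__ == "__main__":
    # Stand-in workload; in the repro, pass load_with_put or load_with_make instead.
    per_call = time_blocking(lambda: jnp.zeros((1000, 1000)) + 1.0, number=10)
    print(f"{per_call * 1e3:.3f} ms per call")
```

With the repro functions defined, `time_blocking(load_with_put)` and `time_blocking(load_with_make)` would give per-call averages that exclude the first call's overhead. Note that both loaders still parse the XML and call `put_model()`, so any difference between them comes only from the data-creation step.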

Minimal model for reproduction

No response

Code required for reproduction

No response

Confirmations

Balint-H added the bug (Something isn't working) and MJX (Using JAX to run on GPU) labels on Feb 25, 2025.