From 3f32cc24791f1f9c32ce6b41d5ffa464ab0925d7 Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Fri, 10 Jan 2025 06:15:42 -0800 Subject: [PATCH] Add mjSpec binding to MJX. PiperOrigin-RevId: 714026500 Change-Id: I5cb3fe6cc975b3aa43d44185e3bc186b45845cf3 --- mjx/mujoco/mjx/_src/support.py | 173 +++++++++++++++++++++++++++- mjx/mujoco/mjx/_src/support_test.py | 86 ++++++++++++++ 2 files changed, 257 insertions(+), 2 deletions(-) diff --git a/mjx/mujoco/mjx/_src/support.py b/mjx/mujoco/mjx/_src/support.py index 1f3c8137ad..18992535c8 100644 --- a/mjx/mujoco/mjx/_src/support.py +++ b/mjx/mujoco/mjx/_src/support.py @@ -13,7 +13,8 @@ # limitations under the License. # ============================================================================== """Engine support functions.""" -from typing import Optional, Tuple, Union +from collections.abc import Sequence +from typing import Optional, Tuple, Union, Any import jax from jax import numpy as jp @@ -236,7 +237,7 @@ def _getadr( def id2name( m: Union[Model, mujoco.MjModel], typ: mujoco._enums.mjtObj, i: int ) -> Optional[str]: - """Gets the name of an object with the specified mjtObj type and id. + """Gets the name of an object with the specified mjtObj type and ids. See mujoco.id2name for more info. @@ -284,6 +285,174 @@ def name2id( return names_map.get(name, -1) +class BindModel(object): + """Class holding the requested MJX Model and spec id for binding a spec to Model.""" + + def __init__(self, model: Model, specs: Sequence[Any]): + self.model = model + try: + iter(specs) + except TypeError: + specs = [specs] + ids = [] + for spec in specs: + match spec: + case mujoco.MjsBody(): + self.prefix = 'body_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name)) + case mujoco.MjsJoint(): + self.prefix = 'jnt_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name)) + case mujoco.MjsGeom(): + self.prefix = 'geom_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name)) + case mujoco.MjsSite(): + self.prefix = 'site_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name)) + case mujoco.MjsLight(): + self.prefix = 'light_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name)) + case mujoco.MjsCamera(): + self.prefix = 'cam_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name)) + case mujoco.MjsMesh(): + self.prefix = 'mesh_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_MESH, spec.name)) + case mujoco.MjsHfield(): + self.prefix = 'hfield_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_HFIELD, spec.name)) + case mujoco.MjsPair(): + self.prefix = 'pair_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_PAIR, spec.name)) + case mujoco.MjsTendon(): + self.prefix = 'tendon_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name)) + case mujoco.MjsActuator(): + self.prefix = 'actuator_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name)) + case mujoco.MjsSensor(): + self.prefix = 'sensor_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name)) + case mujoco.MjsNumeric(): + self.prefix = 'numeric_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_NUMERIC, spec.name)) + case mujoco.MjsText(): + self.prefix = 'text_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TEXT, spec.name)) + case mujoco.MjsTuple(): + self.prefix = 'tuple_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TUPLE, spec.name)) + case mujoco.MjsKey(): + self.prefix = 'key_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_KEY, spec.name)) + case mujoco.MjsEquality(): + self.prefix = 'eq_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name)) + case mujoco.MjsExclude(): + self.prefix = 'exclude_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_EXCLUDE, spec.name)) + case mujoco.MjsSkin(): + self.prefix = 'skin_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SKIN, spec.name)) + case mujoco.MjsMaterial(): + self.prefix = 'material_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_MATERIAL, spec.name)) + case _: + raise ValueError('invalid spec type') + if len(ids) == 1: + self.id = ids[0] + else: + self.id = ids + + def __getattr__(self, name: str): + return getattr(self.model, self.prefix + name)[self.id, :] + + +def _bind_model(self: Model, obj: Sequence[Any]) -> BindModel: + """Bind a Mujoco spec to an MJX Model.""" + return BindModel(self, obj) + + +class BindData(object): + """Class holding the requested MJX Data and spec id for binding a spec to Data.""" + + def __init__(self, data: Data, model: Model, specs: Sequence[Any]): + self.data = data + try: + iter(specs) + except TypeError: + specs = [specs] + ids = [] + for spec in specs: + match spec: + case mujoco.MjsBody(): + self.prefix = '' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name)) + case mujoco.MjsJoint(): + self.prefix = 'jnt_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name)) + case mujoco.MjsGeom(): + self.prefix = 'geom_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name)) + case mujoco.MjsSite(): + self.prefix = 'site_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name)) + case mujoco.MjsLight(): + self.prefix = 'light_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name)) + case mujoco.MjsCamera(): + self.prefix = 'cam_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name)) + case mujoco.MjsTendon(): + self.prefix = 'tendon_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name)) + case mujoco.MjsActuator(): + self.prefix = 'actuator_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name)) + case mujoco.MjsSensor(): + self.prefix = 'sensor_' + ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name)) + case _: + raise ValueError('invalid spec type') + if len(ids) == 1: + self.id = ids[0] + else: + self.id = ids + + def __getname(self, name: str): + try: + getattr(self.data, self.prefix + name) + return self.prefix + name + except AttributeError: + try: + getattr(self.data, name) + return name + except AttributeError as e: + raise ValueError(f'invalid name: {name}') from e + + def __getattr__(self, name: str): + return getattr(self.data, self.__getname(name))[self.id, ...] + + def set(self, name: str, value: jax.Array) -> Data: + """Set the value of an array in an MJX Data.""" + array = getattr(self.data, self.__getname(name)) + if len(value) == 1: + array = array.at[self.id].set(value[0]) + else: + for i, v in enumerate(value): + array = array.at[self.id[i]].set(v) + return self.data.replace(**{self.__getname(name): array}) + + +def _bind_data(self: Data, model: Model, obj: Sequence[Any]) -> BindData: + """Bind a Mujoco spec to an MJX Data.""" + return BindData(self, model, obj) + + +Model.bind = _bind_model +Data.bind = _bind_data + + def _decode_pyramid( pyramid: jax.Array, mu: jax.Array, condim: int ) -> jax.Array: diff --git a/mjx/mujoco/mjx/_src/support_test.py b/mjx/mujoco/mjx/_src/support_test.py index c955ba72f3..9add48b372 100644 --- a/mjx/mujoco/mjx/_src/support_test.py +++ b/mjx/mujoco/mjx/_src/support_test.py @@ -157,6 +157,92 @@ def test_names_and_ids(self): i = i if n is not None else -1 self.assertEqual(support.name2id(mx, obj, n), i) + def test_bind(self): + xml = """ + + + + + + + + + + + + + + + + + + + + + + + """ + + s = mujoco.MjSpec.from_string(xml) + m = s.compile() + d = mujoco.MjData(m) + mx = mjx.put_model(m) + dx = mjx.put_data(m, d) + mujoco.mj_step(m, d) + dx = mjx.step(mx, dx) + + # test getting + np.testing.assert_array_equal(mx.bind(s.bodies).pos, m.body_pos) + np.testing.assert_array_equal(dx.bind(mx, s.bodies).xpos, d.xpos) + for i in range(m.nbody): + np.testing.assert_array_equal(m.bind(s.bodies[i]).pos, m.body_pos[i, :]) + np.testing.assert_array_equal(mx.bind(s.bodies[i]).pos, m.body_pos[i, :]) + np.testing.assert_array_equal(d.bind(s.bodies[i]).xpos, d.xpos[i, :]) + np.testing.assert_array_equal( + dx.bind(mx, s.bodies[i]).xpos, d.xpos[i, :] + ) + + np.testing.assert_array_equal(mx.bind(s.geoms).size, m.geom_size) + np.testing.assert_array_equal(dx.bind(mx, s.geoms).xpos, d.geom_xpos) + for i in range(m.ngeom): + np.testing.assert_array_equal(m.bind(s.geoms[i]).size, m.geom_size[i, :]) + np.testing.assert_array_equal(mx.bind(s.geoms[i]).size, m.geom_size[i, :]) + np.testing.assert_array_equal(d.bind(s.geoms[i]).xpos, d.geom_xpos[i, :]) + np.testing.assert_array_equal( + dx.bind(mx, s.geoms[i]).xpos, d.geom_xpos[i, :] + ) + + np.testing.assert_array_equal(mx.bind(s.joints).axis, m.jnt_axis) + for i in range(m.njnt): + np.testing.assert_array_equal(m.bind(s.joints[i]).axis, m.jnt_axis[i, :]) + np.testing.assert_array_equal(mx.bind(s.joints[i]).axis, m.jnt_axis[i, :]) + + np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, d.ctrl) + for i in range(m.nu): + np.testing.assert_array_equal(d.bind(s.actuators[i]).ctrl, d.ctrl[i]) + np.testing.assert_array_equal( + dx.bind(mx, s.actuators[i]).ctrl, d.ctrl[i] + ) + + # test setting + np.testing.assert_array_equal(d.ctrl, [0, 0, 0]) + np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, d.ctrl) + dx2 = dx.bind(mx, s.actuators).set('ctrl', [1, 2, 3]) + np.testing.assert_array_equal(dx2.bind(mx, s.actuators).ctrl, [1, 2, 3]) + np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0]) + dx3 = dx.bind(mx, s.actuators[1:]).set('ctrl', [4, 5]) + np.testing.assert_array_equal(dx3.bind(mx, s.actuators).ctrl, [0, 4, 5]) + np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0]) + dx4 = dx.bind(mx, s.actuators[1]).set('ctrl', [6]) + np.testing.assert_array_equal(dx4.bind(mx, s.actuators).ctrl, [0, 6, 0]) + np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0]) + + # test invalid name + with self.assertRaises(ValueError): + print(dx.bind(mx, s.actuators).actuator_ctrl) + with self.assertRaises(ValueError): + print(dx.bind(mx, s.actuators).set('actuator_ctrl', [1, 2, 3])) + _CONTACTS = """