Skip to content

Commit

Permalink
Add mjSpec binding to MJX.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714026500
Change-Id: I5cb3fe6cc975b3aa43d44185e3bc186b45845cf3
  • Loading branch information
quagla authored and copybara-github committed Jan 10, 2025
1 parent 6f7fba7 commit 3f32cc2
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 2 deletions.
173 changes: 171 additions & 2 deletions mjx/mujoco/mjx/_src/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Engine support functions."""
from typing import Optional, Tuple, Union
from collections.abc import Sequence
from typing import Optional, Tuple, Union, Any

import jax
from jax import numpy as jp
Expand Down Expand Up @@ -236,7 +237,7 @@ def _getadr(
def id2name(
m: Union[Model, mujoco.MjModel], typ: mujoco._enums.mjtObj, i: int
) -> Optional[str]:
"""Gets the name of an object with the specified mjtObj type and id.
"""Gets the name of an object with the specified mjtObj type and ids.
See mujoco.id2name for more info.
Expand Down Expand Up @@ -284,6 +285,174 @@ def name2id(
return names_map.get(name, -1)


class BindModel(object):
"""Class holding the requested MJX Model and spec id for binding a spec to Model."""

def __init__(self, model: Model, specs: Sequence[Any]):
self.model = model
try:
iter(specs)
except TypeError:
specs = [specs]
ids = []
for spec in specs:
match spec:
case mujoco.MjsBody():
self.prefix = 'body_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name))
case mujoco.MjsJoint():
self.prefix = 'jnt_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name))
case mujoco.MjsGeom():
self.prefix = 'geom_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name))
case mujoco.MjsSite():
self.prefix = 'site_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name))
case mujoco.MjsLight():
self.prefix = 'light_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name))
case mujoco.MjsCamera():
self.prefix = 'cam_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name))
case mujoco.MjsMesh():
self.prefix = 'mesh_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_MESH, spec.name))
case mujoco.MjsHfield():
self.prefix = 'hfield_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_HFIELD, spec.name))
case mujoco.MjsPair():
self.prefix = 'pair_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_PAIR, spec.name))
case mujoco.MjsTendon():
self.prefix = 'tendon_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name))
case mujoco.MjsActuator():
self.prefix = 'actuator_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name))
case mujoco.MjsSensor():
self.prefix = 'sensor_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name))
case mujoco.MjsNumeric():
self.prefix = 'numeric_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_NUMERIC, spec.name))
case mujoco.MjsText():
self.prefix = 'text_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TEXT, spec.name))
case mujoco.MjsTuple():
self.prefix = 'tuple_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TUPLE, spec.name))
case mujoco.MjsKey():
self.prefix = 'key_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_KEY, spec.name))
case mujoco.MjsEquality():
self.prefix = 'eq_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name))
case mujoco.MjsExclude():
self.prefix = 'exclude_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_EXCLUDE, spec.name))
case mujoco.MjsSkin():
self.prefix = 'skin_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SKIN, spec.name))
case mujoco.MjsMaterial():
self.prefix = 'material_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_MATERIAL, spec.name))
case _:
raise ValueError('invalid spec type')
if len(ids) == 1:
self.id = ids[0]
else:
self.id = ids

def __getattr__(self, name: str):
return getattr(self.model, self.prefix + name)[self.id, :]


def _bind_model(self: Model, obj: Sequence[Any]) -> BindModel:
"""Bind a Mujoco spec to an MJX Model."""
return BindModel(self, obj)


class BindData(object):
"""Class holding the requested MJX Data and spec id for binding a spec to Data."""

def __init__(self, data: Data, model: Model, specs: Sequence[Any]):
self.data = data
try:
iter(specs)
except TypeError:
specs = [specs]
ids = []
for spec in specs:
match spec:
case mujoco.MjsBody():
self.prefix = ''
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name))
case mujoco.MjsJoint():
self.prefix = 'jnt_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name))
case mujoco.MjsGeom():
self.prefix = 'geom_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name))
case mujoco.MjsSite():
self.prefix = 'site_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name))
case mujoco.MjsLight():
self.prefix = 'light_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name))
case mujoco.MjsCamera():
self.prefix = 'cam_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name))
case mujoco.MjsTendon():
self.prefix = 'tendon_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name))
case mujoco.MjsActuator():
self.prefix = 'actuator_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name))
case mujoco.MjsSensor():
self.prefix = 'sensor_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name))
case _:
raise ValueError('invalid spec type')
if len(ids) == 1:
self.id = ids[0]
else:
self.id = ids

def __getname(self, name: str):
try:
getattr(self.data, self.prefix + name)
return self.prefix + name
except AttributeError:
try:
getattr(self.data, name)
return name
except AttributeError as e:
raise ValueError(f'invalid name: {name}') from e

def __getattr__(self, name: str):
return getattr(self.data, self.__getname(name))[self.id, ...]

def set(self, name: str, value: jax.Array) -> Data:
"""Set the value of an array in an MJX Data."""
array = getattr(self.data, self.__getname(name))
if len(value) == 1:
array = array.at[self.id].set(value[0])
else:
for i, v in enumerate(value):
array = array.at[self.id[i]].set(v)
return self.data.replace(**{self.__getname(name): array})


def _bind_data(self: Data, model: Model, obj: Sequence[Any]) -> BindData:
"""Bind a Mujoco spec to an MJX Data."""
return BindData(self, model, obj)


Model.bind = _bind_model
Data.bind = _bind_data


def _decode_pyramid(
pyramid: jax.Array, mu: jax.Array, condim: int
) -> jax.Array:
Expand Down
86 changes: 86 additions & 0 deletions mjx/mujoco/mjx/_src/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,92 @@ def test_names_and_ids(self):
i = i if n is not None else -1
self.assertEqual(support.name2id(mx, obj, n), i)

def test_bind(self):
xml = """
<mujoco model="test_bind_model">
<worldbody>
<body pos="1 2 3" name="body1">
<joint axis="1 0 0" type="slide" name="joint1"/>
<geom size="1 2 3" type="box" name="geom1"/>
</body>
<body pos="4 5 6" name="body2">
<joint axis="0 1 0" type="slide" name="joint2"/>
<geom size="4 5 6" type="box" name="geom2"/>
</body>
<body pos="7 8 9" name="body3">
<joint axis="0 0 1" type="slide" name="joint3"/>
<geom size="7 8 9" type="box" name="geom3"/>
</body>
</worldbody>
<actuator>
<motor name="actuator1" joint="joint1"/>
<motor name="actuator2" joint="joint2"/>
<motor name="actuator3" joint="joint3"/>
</actuator>
</mujoco>
"""

s = mujoco.MjSpec.from_string(xml)
m = s.compile()
d = mujoco.MjData(m)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)
mujoco.mj_step(m, d)
dx = mjx.step(mx, dx)

# test getting
np.testing.assert_array_equal(mx.bind(s.bodies).pos, m.body_pos)
np.testing.assert_array_equal(dx.bind(mx, s.bodies).xpos, d.xpos)
for i in range(m.nbody):
np.testing.assert_array_equal(m.bind(s.bodies[i]).pos, m.body_pos[i, :])
np.testing.assert_array_equal(mx.bind(s.bodies[i]).pos, m.body_pos[i, :])
np.testing.assert_array_equal(d.bind(s.bodies[i]).xpos, d.xpos[i, :])
np.testing.assert_array_equal(
dx.bind(mx, s.bodies[i]).xpos, d.xpos[i, :]
)

np.testing.assert_array_equal(mx.bind(s.geoms).size, m.geom_size)
np.testing.assert_array_equal(dx.bind(mx, s.geoms).xpos, d.geom_xpos)
for i in range(m.ngeom):
np.testing.assert_array_equal(m.bind(s.geoms[i]).size, m.geom_size[i, :])
np.testing.assert_array_equal(mx.bind(s.geoms[i]).size, m.geom_size[i, :])
np.testing.assert_array_equal(d.bind(s.geoms[i]).xpos, d.geom_xpos[i, :])
np.testing.assert_array_equal(
dx.bind(mx, s.geoms[i]).xpos, d.geom_xpos[i, :]
)

np.testing.assert_array_equal(mx.bind(s.joints).axis, m.jnt_axis)
for i in range(m.njnt):
np.testing.assert_array_equal(m.bind(s.joints[i]).axis, m.jnt_axis[i, :])
np.testing.assert_array_equal(mx.bind(s.joints[i]).axis, m.jnt_axis[i, :])

np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, d.ctrl)
for i in range(m.nu):
np.testing.assert_array_equal(d.bind(s.actuators[i]).ctrl, d.ctrl[i])
np.testing.assert_array_equal(
dx.bind(mx, s.actuators[i]).ctrl, d.ctrl[i]
)

# test setting
np.testing.assert_array_equal(d.ctrl, [0, 0, 0])
np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, d.ctrl)
dx2 = dx.bind(mx, s.actuators).set('ctrl', [1, 2, 3])
np.testing.assert_array_equal(dx2.bind(mx, s.actuators).ctrl, [1, 2, 3])
np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0])
dx3 = dx.bind(mx, s.actuators[1:]).set('ctrl', [4, 5])
np.testing.assert_array_equal(dx3.bind(mx, s.actuators).ctrl, [0, 4, 5])
np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0])
dx4 = dx.bind(mx, s.actuators[1]).set('ctrl', [6])
np.testing.assert_array_equal(dx4.bind(mx, s.actuators).ctrl, [0, 6, 0])
np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0])

# test invalid name
with self.assertRaises(ValueError):
print(dx.bind(mx, s.actuators).actuator_ctrl)
with self.assertRaises(ValueError):
print(dx.bind(mx, s.actuators).set('actuator_ctrl', [1, 2, 3]))

_CONTACTS = """
<mujoco>
<worldbody>
Expand Down

0 comments on commit 3f32cc2

Please sign in to comment.