Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Energy calculations in MJX #2314

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion mjx/mujoco/mjx/_src/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,13 @@ def forward(m: Model, d: Data) -> Data:
"""Forward dynamics."""
d = fwd_position(m, d)
d = sensor.sensor_pos(m, d)
d = sensor.energy_pos(m, d)
d = fwd_velocity(m, d)
d = sensor.sensor_vel(m, d)
d = fwd_actuation(m, d)
d = sensor.energy_vel(m, d)
d = fwd_acceleration(m, d)
d = sensor.sensor_acc(m, d)

if d.efc_J.size == 0:
d = d.replace(qacc=d.qacc_smooth)
return d
Expand Down
8 changes: 5 additions & 3 deletions mjx/mujoco/mjx/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def _make_option(
if o.solver not in set(types.SolverType):
raise NotImplementedError(f'{mujoco.mjtSolver(o.solver)}')

for i in range(mujoco.mjtEnableBit.mjNENABLE):
if o.enableflags & 2**i:
raise NotImplementedError(f'{mujoco.mjtEnableBit(2 ** i)}')
# Check enable flags using enum pattern
if types.EnableBit(o.enableflags) not in set(types.EnableBit) and o.enableflags != 0:
raise NotImplementedError(f'{mujoco.mjtEnableBit(o.enableflags)}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think mujoco.mjtEnableBit(o.enableflags) works when you have multiple bits set, you will get an error like this:

ValueError: Invalid int value for mjtEnableBit: 3


has_fluid_params = o.density > 0 or o.viscosity > 0 or o.wind.any()
implicitfast = o.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST
Expand All @@ -71,6 +71,7 @@ def _make_option(
fields['jacobian'] = types.JacobianType(o.jacobian)
fields['solver'] = types.SolverType(o.solver)
fields['disableflags'] = types.DisableBit(o.disableflags)
fields['enableflags'] = types.EnableBit(o.enableflags)
fields['has_fluid_params'] = has_fluid_params

return types.Option(**fields)
Expand Down Expand Up @@ -363,6 +364,7 @@ def make_data(
'_qM_sparse': (m.nM, float),
'_qLD_sparse': (m.nM, float),
'_qLDiagInv_sparse': (m.nv, float),
'energy': (2, float),
}

if not _full_compat:
Expand Down
85 changes: 85 additions & 0 deletions mjx/mujoco/mjx/_src/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
from mujoco.mjx._src import ray
from mujoco.mjx._src import smooth
from mujoco.mjx._src import support
from mujoco.mjx._src import scan
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import Model
from mujoco.mjx._src.types import ObjType
from mujoco.mjx._src.types import SensorType
from mujoco.mjx._src.types import JointType
from mujoco.mjx._src.types import EnableBit
# pylint: enable=g-importing-member
import numpy as np

Expand Down Expand Up @@ -600,3 +603,85 @@ def _framelinacc(cvel, cacc, offset):
)

return d.replace(sensordata=sensordata)


def energy_pos(m: Model, d: Data) -> Data:
"""Calculates position-dependent energy (potential).
"""

if not m.opt.enableflags & EnableBit.ENERGY:
return d

# Initialize potential energy
energy = jp.array(0.0)

# Add gravitational potential energy for each body
if not m.opt.disableflags & DisableBit.GRAVITY:
energy = -jp.sum(m.body_mass[1:] * jp.dot(d.xipos[1:,:], m.opt.gravity))

# Add joint spring potential energy using scan.flat
if not m.opt.disableflags & DisableBit.PASSIVE:
def spring_energy(jnt_type, stiffness, qpos, qpos_spring, padr):

if jnt_type == JointType.FREE:
# Position springs
quat = qpos[padr:padr+4]
quat = math.normalize(quat)
dif = quat - qpos_spring[padr:padr+4]
energy = 0.5 * stiffness * jp.dot(dif[:3], dif[:3])

elif jnt_type in (JointType.FREE, JointType.BALL):
# Convert quaternion difference to angular displacement
quat = qpos[padr:padr+4]
quat = math.normalize(quat)
dif = math.quat_sub(quat, qpos_spring[padr:padr+4])
energy = 0.5 * stiffness * jp.dot(dif, dif)

elif jnt_type in (JointType.SLIDE, JointType.HINGE):
dif = qpos[padr] - qpos_spring[padr]
energy = 0.5 * stiffness * dif * dif

return energy

spring_energy = scan.flat(
m,
spring_energy,
'jjqqj', # input types: jnt_type, stiffness, qpos, qpos_spring, padr
'j', # output type: energy per joint
m.jnt_type,
m.jnt_stiffness,
d.qpos,
m.qpos_spring,
jp.array(m.jnt_qposadr),
group_by='j'
)

energy += jp.sum(spring_energy)

# Add tendon spring potential energy using vectorized operations
if not m.opt.disableflags & DisableBit.PASSIVE & m.tendon_lengthspring.size > 0:
# Get lower/upper bounds and current lengths
lower = m.tendon_lengthspring[::2] # Even indices
upper = m.tendon_lengthspring[1::2] # Odd indices
length = d.ten_length

# Compute displacements using vectorized operations
displacement = jp.where(length > upper, upper - length, 0.0)
displacement = jp.where(length < lower, lower - length, displacement)

# Compute spring energy for all tendons at once
energy += 0.5 * jp.sum(m.tendon_stiffness * displacement * displacement)

return d.replace(energy=d.energy.at[0].set(energy))


def energy_vel(m: Model, d: Data) -> Data:
"""Calculates velocity-dependent energy (kinetic).
"""
if not m.opt.enableflags & EnableBit.ENERGY:
return d

vec = support.mul_m(m, d, d.qvel)
energy = 0.5 * jp.dot(vec, d.qvel)

return d.replace(energy=d.energy.at[1].set(energy))
13 changes: 12 additions & 1 deletion mjx/mujoco/mjx/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ class DisableBit(enum.IntFlag):
# unsupported: MIDPHASE


class EnableBit(enum.IntFlag):
"""Enable optional feature bitflags.

Members:
ENERGY: enable energy computation
"""

ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY

class JointType(enum.IntEnum):
"""Type of degree of freedom.

Expand Down Expand Up @@ -482,7 +491,7 @@ class Option(PyTreeNode):
noslip_iterations: int = _restricted_to('mujoco')
ccd_iterations: int = _restricted_to('mujoco')
disableflags: DisableBit
enableflags: int
enableflags: EnableBit
disableactuator: int
sdf_initpoints: int = _restricted_to('mujoco')
sdf_iterations: int = _restricted_to('mujoco')
Expand Down Expand Up @@ -1214,6 +1223,7 @@ class Data(PyTreeNode):
ncon: number of contacts
solver_niter: number of solver iterations
time: simulation time
energy: potential, kinetic energy (2, )
qpos: position (nq,)
qvel: velocity (nv,)
act: actuator activation (na,)
Expand Down Expand Up @@ -1341,6 +1351,7 @@ class Data(PyTreeNode):
solver_niter: jax.Array
# global properties:
time: jax.Array
energy: jax.Array
# state:
qpos: jax.Array
qvel: jax.Array
Expand Down