Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Energy calculations in MJX #2314

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

simeon-ned
Copy link

This PR adds support for kinetic and potential energy calculations in MJX, matching MuJoCo's implementation.

Key changes:

  • Add energy_pos() and energy_vel() functions to sensor.py for calculating potential and kinetic energy respectively
  • Add EnableBit.ENERGY flag support to control energy calculations

Copy link

google-cla bot commented Dec 27, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@yuvaltassa yuvaltassa requested a review from erikfrey December 28, 2024 10:38
@erikfrey
Copy link
Collaborator

@simeon-ned what a nice PR! Thank you for creating it. Things are still slow due to the holidays, but someone will get to this soon.

@simeon-ned
Copy link
Author

Hello @erikfrey, is there anything I can do to help with this PR? I could add tests similar to those implemented in mjx src.

raise NotImplementedError(f'{mujoco.mjtEnableBit(2 ** i)}')
# Check enable flags using enum pattern
if types.EnableBit(o.enableflags) not in set(types.EnableBit) and o.enableflags != 0:
raise NotImplementedError(f'{mujoco.mjtEnableBit(o.enableflags)}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think mujoco.mjtEnableBit(o.enableflags) works when you have multiple bits set, you will get an error like this:

ValueError: Invalid int value for mjtEnableBit: 3


# Add joint spring potential energy
if not m.opt.disableflags & DisableBit.PASSIVE:
for i in range(m.njnt):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I challenge you to figure out how to do this without this for loop iterating over joints. You should find code elsewhere that does this: look for usage of scan.flat in smooth.py for some examples.

scan is a bit gnarly and not well super well documented, but should be used wherever possible.


# Add tendon spring potential energy
if not m.opt.disableflags & DisableBit.PASSIVE:
for i in range(m.ntendon):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this for loop can be removed, and you should not need scan at all.

you can do something like:

lower = m.tendon_lengthspring[0::2]
upper = m.tendon_lengthspring[1::2]
length = d.ten_length

displacement = jp.where(length > upper, upper - length, 0)
displacement = jp.where(length < lower, lower - length, displacement)
...

"""

# Initialize potential energy
energy = jp.array(0.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mega nitpick: please move this variable definition to right after the if statement guard


# Add tendon spring potential energy using vectorized operations
if not m.opt.disableflags & DisableBit.PASSIVE:
if not m.opt.disableflags & DisableBit.PASSIVE & m.tendon_lengthspring.size > 0:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Order of logical operations is probably incorrect. This evaluates to true even if m.tendon_lengthspring.size > 0 is false, which leads to access to non-existent elements of this array.
(not m.opt.disableflags & DisableBit.PASSIVE) & (m.tendon_lengthspring.size > 0) should work.


elif jnt_type in (JointType.FREE, JointType.BALL):
# Convert quaternion difference to angular displacement
quat = qpos[padr:padr+4]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably slicing syntax in this form is not supported, as jax requires static start/stop/step to be used with NumPy indexing syntax.

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[1])>with<BatchTrace> with
  val = Traced<ShapedArray(int32[1,1])>with<DynamicJaxprTrace>
  batch_dim = 0, Traced<ShapedArray(int32[1])>with<BatchTrace> with
  val = Traced<ShapedArray(int32[1,1])>with<DynamicJaxprTrace>
  batch_dim = 0, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

This happens when I try to jit compile functions which take derivatives of energies.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants