Skip to content

Commit

Permalink
Allocate qM CSR matrix structure on mjData (currently empty).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733247419
Change-Id: I814e574e3161875e89433a6e03f2c9e200d8b1d6
  • Loading branch information
yuvaltassa authored and copybara-github committed Mar 4, 2025
1 parent 8f37ae1 commit c72c83a
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/includes/references.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ struct mjData_ {
int* B_rownnz; // body-dof: non-zeros in each row (nbody x 1)
int* B_rowadr; // body-dof: address of each row in B_colind (nbody x 1)
int* B_colind; // body-dof: column indices of non-zeros (nB x 1)
int* M_rownnz; // inertia: non-zeros in each row (nv x 1)
int* M_rowadr; // inertia: address of each row in M_colind (nv x 1)
int* M_colind; // inertia: column indices of non-zeros (nM x 1)
int* mapM2M; // index mapping from M (legacy) to M (CSR) (nM x 1)
int* C_rownnz; // reduced dof-dof: non-zeros in each row (nv x 1)
int* C_rowadr; // reduced dof-dof: address of each row in C_colind (nv x 1)
int* C_colind; // reduced dof-dof: column indices of non-zeros (nC x 1)
Expand Down
4 changes: 4 additions & 0 deletions include/mujoco/mjdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ struct mjData_ {
int* B_rownnz; // body-dof: non-zeros in each row (nbody x 1)
int* B_rowadr; // body-dof: address of each row in B_colind (nbody x 1)
int* B_colind; // body-dof: column indices of non-zeros (nB x 1)
int* M_rownnz; // inertia: non-zeros in each row (nv x 1)
int* M_rowadr; // inertia: address of each row in M_colind (nv x 1)
int* M_colind; // inertia: column indices of non-zeros (nM x 1)
int* mapM2M; // index mapping from M (legacy) to M (CSR) (nM x 1)
int* C_rownnz; // reduced dof-dof: non-zeros in each row (nv x 1)
int* C_rowadr; // reduced dof-dof: address of each row in C_colind (nv x 1)
int* C_colind; // reduced dof-dof: column indices of non-zeros (nC x 1)
Expand Down
4 changes: 4 additions & 0 deletions include/mujoco/mjxmacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,10 @@
X ( int, B_rownnz, nbody, 1 ) \
X ( int, B_rowadr, nbody, 1 ) \
X ( int, B_colind, nB, 1 ) \
X ( int, M_rownnz, nv, 1 ) \
X ( int, M_rowadr, nv, 1 ) \
X ( int, M_colind, nM, 1 ) \
X ( int, mapM2M, nM, 1 ) \
X ( int, C_rownnz, nv, 1 ) \
X ( int, C_rowadr, nv, 1 ) \
X ( int, C_colind, nC, 1 ) \
Expand Down
4 changes: 4 additions & 0 deletions mjx/mujoco/mjx/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ def make_data(
'B_rownnz': (m.nbody, jp.int32),
'B_rowadr': (m.nbody, jp.int32),
'B_colind': (m.nB, jp.int32),
'M_rownnz': (m.nv, jp.int32),
'M_rowadr': (m.nv, jp.int32),
'M_colind': (m.nM, jp.int32),
'mapM2M': (m.nM, jp.int32),
'C_rownnz': (m.nv, jp.int32),
'C_rowadr': (m.nv, jp.int32),
'C_colind': (m.nC, jp.int32),
Expand Down
8 changes: 8 additions & 0 deletions mjx/mujoco/mjx/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,10 @@ class Data(PyTreeNode):
B_rownnz: body-dof: non-zeros in each row (nbody,)
B_rowadr: body-dof: address of each row in B_colind (nbody,)
B_colind: body-dof: column indices of non-zeros (nB,)
M_rownnz: inertia: non-zeros in each row (nv,)
M_rowadr: inertia: address of each row in M_colind (nv,)
M_colind: inertia: column indices of non-zeros (nM,)
mapM2M: index mapping from M (legacy) to M (CSR) (nM,)
C_rownnz: reduced dof-dof: non-zeros in each row (nv,)
C_rowadr: reduced dof-dof: address of each row in C_colind (nv,)
C_colind: reduced dof-dof: column indices of non-zeros (nC,)
Expand Down Expand Up @@ -1445,6 +1449,10 @@ class Data(PyTreeNode):
B_rownnz: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
B_rowadr: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
B_colind: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
M_rownnz: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
M_rowadr: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
M_colind: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
mapM2M: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
C_rownnz: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
C_rowadr: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
C_colind: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
Expand Down
32 changes: 32 additions & 0 deletions python/mujoco/introspect/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5526,6 +5526,38 @@
doc='body-dof: column indices of non-zeros',
array_extent=('nB',),
),
StructFieldDecl(
name='M_rownnz',
type=PointerType(
inner_type=ValueType(name='int'),
),
doc='inertia: non-zeros in each row',
array_extent=('nv',),
),
StructFieldDecl(
name='M_rowadr',
type=PointerType(
inner_type=ValueType(name='int'),
),
doc='inertia: address of each row in M_colind',
array_extent=('nv',),
),
StructFieldDecl(
name='M_colind',
type=PointerType(
inner_type=ValueType(name='int'),
),
doc='inertia: column indices of non-zeros',
array_extent=('nM',),
),
StructFieldDecl(
name='mapM2M',
type=PointerType(
inner_type=ValueType(name='int'),
),
doc='index mapping from M (legacy) to M (CSR)',
array_extent=('nM',),
),
StructFieldDecl(
name='C_rownnz',
type=PointerType(
Expand Down
4 changes: 4 additions & 0 deletions unity/Runtime/Bindings/MjBindings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4942,6 +4942,10 @@ public unsafe struct mjData_ {
public int* B_rownnz;
public int* B_rowadr;
public int* B_colind;
public int* M_rownnz;
public int* M_rowadr;
public int* M_colind;
public int* mapM2M;
public int* C_rownnz;
public int* C_rowadr;
public int* C_colind;
Expand Down

0 comments on commit c72c83a

Please sign in to comment.