Skip to content

Commit

Permalink
small cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Velythyl committed Feb 16, 2024
1 parent b6cab64 commit 3da30d0
Showing 1 changed file with 5 additions and 65 deletions.
70 changes: 5 additions & 65 deletions brax/envs/wrappers/vsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,74 +95,14 @@ def _write_sys(sys, attr, val):
_write_sys(getattr(sys, attr[0]), attr[1:], val)})


def set_sys(sys, params: Dict[str, jp.ndarray]):
"""Sets params in the System."""
for k in params.keys():
sys = _write_sys(sys, k.split('.'), params[k])
return sys


def set_sys_capsules(sys, lengths, radii):
"""Sets the system with new capsule lengths/radii."""
sys2 = set_sys(sys, {'geoms.length': lengths})
sys2 = set_sys(sys2, {'geoms.radius': radii})

# we assume inertia.transform.pos is (0,0,0), as is often the case for
# capsules

# get the new joint transform
cur_len = sys.geoms[1].length[:, None]
joint_dir = jax.vmap(math.normalize)(sys.link.joint.pos)[0]
joint_dist = sys.link.joint.pos - 0.5 * cur_len * joint_dir
joint_transform = 0.5 * lengths[:, None] * joint_dir + joint_dist
sys2 = set_sys(sys2, {'link.joint.pos': joint_transform})

# get the new link transform
parent_idx = jp.array([sys.link_parents])
sys2 = set_sys(
sys2,
{
'link.transform.pos': -(
sys2.link.joint.pos
+ joint_dist
+ 0.5 * lengths[parent_idx].T * joint_dir
)
},
)
return sys2
# TODO traverse_sys and write_sys can probably be collapsed into one function


def util_vmap_set(sys, keys, vals):
dico = dict(zip(keys, vals))

return set_sys(sys, dico)


def randomize(sys, rng):
return set_sys(sys,
{'link.inertia.mass': sys.link.inertia.mass + jax.random.uniform(rng, shape=(sys.num_links(),))})


@jax.jit
def randomize_sys_capsules(
rng: jp.ndarray,
sys: base.System,
min_length: float = 0.0,
max_length: float = 0.0,
min_radius: float = 0.0,
max_radius: float = 0.0,
):
"""Randomizes joint offsets, assume capsule geoms appear in geoms[1]."""
rng, key1, key2 = jax.random.split(rng, 3)
length_u = jax.random.uniform(
key1, shape=(sys.num_links(),), minval=min_length, maxval=max_length
)
radius_u = jax.random.uniform(
key2, shape=(sys.num_links(),), minval=min_radius, maxval=max_radius
)
length = length_u + sys.geoms[1].length # pytype: disable=attribute-error
radius = radius_u + sys.geoms[1].radius # pytype: disable=attribute-error
return set_sys_capsules(sys, length, radius)


### RANDOMIZATION CONFIG LOGIC ###
Expand Down Expand Up @@ -337,9 +277,9 @@ def resamply(sys: System, vals: List[jp.ndarray]) -> System:

sys = jax.lax.cond(
mask,
resamply, # true
identity, # false
sys, vals # operands
resamply,
identity,
sys, vals
)
return sys

Expand Down Expand Up @@ -569,7 +509,7 @@ def reset(self, rng: jp.ndarray) -> State:

x = make_skrs(env, "./inertia.yaml")

env = DomainRandVSysWrapper(env, x, (4,8))
env = DomainRandVSysWrapper(env, x, 1)
#env = IdentityVSysWrapper(env)
#env = DomainCartesianVSysWrapper(env, x, DISCRETIZATION_LEVEL)
key = jax.random.PRNGKey(0)
Expand Down

0 comments on commit 3da30d0

Please sign in to comment.