diff --git a/MANIFEST.in b/MANIFEST.in index 06b1a0ed4..fba54ffcd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include brax/envs/assets/*.xml -recursive-include brax/experimental/barkour_v0 *.csv *.stl *.xml +recursive-include brax/experimental/barkour_vb *.csv *.stl *.xml recursive-include brax/test_data *.xml *.stl *.obj *.urdf recursive-include brax/visualizer * diff --git a/brax/base.py b/brax/base.py index 2242d9a24..35d22f81d 100644 --- a/brax/base.py +++ b/brax/base.py @@ -531,15 +531,15 @@ def qd_idx(self, link_type: str) -> jax.Array: def q_size(self) -> int: """Returns the size of the q vector (joint position) for this system.""" - return sum([Q_WIDTHS[t] for t in self.link_types]) + return self.nq def qd_size(self) -> int: """Returns the size of the qd vector (joint velocity) for this system.""" - return sum([QD_WIDTHS[t] for t in self.link_types]) + return self.nv def act_size(self) -> int: """Returns the act dimension for the system.""" - return self.actuator.q_id.shape[0] + return self.nu # below are some operation dispatch derivations diff --git a/brax/envs/ant.py b/brax/envs/ant.py index aa6b5eee1..2ff166b01 100644 --- a/brax/envs/ant.py +++ b/brax/envs/ant.py @@ -233,6 +233,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Run one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt diff --git a/brax/envs/fast.py b/brax/envs/fast.py index 376df3b3a..f7351b71f 100644 --- a/brax/envs/fast.py +++ b/brax/envs/fast.py @@ -43,6 +43,7 @@ def reset(self, rng: jax.Array) -> State: return State(pipeline_state, obs, reward, done) def step(self, state: State, action: jax.Array) -> State: + assert state.pipeline_state is not None self._step_count += 1 vel = state.pipeline_state.xd.vel + (action > 0) * self._dt pos = state.pipeline_state.x.pos + vel * self._dt diff --git a/brax/envs/half_cheetah.py b/brax/envs/half_cheetah.py index 527c94e14..e99eba608 100644 --- a/brax/envs/half_cheetah.py +++ b/brax/envs/half_cheetah.py @@ -178,6 +178,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) x_velocity = ( diff --git a/brax/envs/hopper.py b/brax/envs/hopper.py index 373248546..516a63361 100644 --- a/brax/envs/hopper.py +++ b/brax/envs/hopper.py @@ -223,6 +223,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) x_velocity = ( diff --git a/brax/envs/pusher.py b/brax/envs/pusher.py index 3abe82f09..e1515d013 100644 --- a/brax/envs/pusher.py +++ b/brax/envs/pusher.py @@ -193,6 +193,7 @@ def reset(self, rng: jax.Array) -> State: return State(pipeline_state, obs, reward, done, metrics) def step(self, state: State, action: jax.Array) -> State: + assert state.pipeline_state is not None x_i = state.pipeline_state.x.vmap().do( base.Transform.create(pos=self.sys.link.inertia.transform.pos) ) diff --git a/brax/envs/walker2d.py b/brax/envs/walker2d.py index d0950e33d..54379e637 100644 --- a/brax/envs/walker2d.py +++ b/brax/envs/walker2d.py @@ -203,6 +203,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) x_velocity = ( diff --git a/brax/experimental/barkour_v0/README.md b/brax/experimental/barkour_v0/README.md deleted file mode 100644 index a2a11cdcc..000000000 --- a/brax/experimental/barkour_v0/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Google Barkour v0 Joystick Policy - -## Overview - -NOTE: For an up-to-date version of Brax training for a quadruped robot with best results, please look at the [MJX colab tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb). The environment in this folder uses an older version of the joystick policy training, which exhibited policies with high variance across seeds. - -This folder contains a training script for a flat-terrain joystick policy for the [Barkour v0 Quadruped](https://ai.googleblog.com/2023/05/barkour-benchmarking-animal-level.html) which demonstrates sim2real transfer. - -`barkour_joystick.py` contains the environment definition, while the [colab](https://colab.research.google.com/github/google/brax/blob/main/brax/experimental/barkour_v0/barkour_v0_joystick.ipynb) shows how to train the policy. - -

- -

- -## Running the environment - -We encourage the usage of the [colab](https://colab.research.google.com/github/google/brax/blob/main/brax/experimental/barkour_v0/barkour_v0_joystick.ipynb) for viewing and training policies. However, the environment can be loaded as follows: - -```python -import jax -from jax import numpy as jp - -from brax import envs -from brax.experimental.barkour_v0 import barkour_joystick - -barkour_env = envs.create('barkour_v0_joystick', backend='generalized') -``` - -And to step through the environment: - -```python -jit_env_reset = jax.jit(barkour_env.reset) -jit_env_step = jax.jit(barkour_env.step) - -state = jit_env_reset(jax.random.PRNGKey(0)) - -rollout = [] -for i in range(500): - act = jp.sin(i / 500) * jp.ones(barkour_env.sys.act_size()) - state = jit_env_step(state, act) - rollout.append(state) -``` - -## MJCF Instructions - -The MuJoCo config in `assets/barkour_v0_brax.xml` was copied from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0). The following edits were made to the MJCF specifically for brax: - -* `meshdir` was changed from `assets` to `.`. -* `frictionloss` was removed. `damping` was changed to 0.5239. -* A custom `init_qpos` was added. -* A sphere geom `lowerLegFoot` was added to all feet. All other contacts were turned off. -* The compiler option was changed to `