PPO on Pixels #560
Conversation
Thanks Andrew for the PR, will review it more thoroughly once the upstream PR is submitted. One initial high-level thought is that the wrapper should go in https://github.com/google/brax/blob/main/brax/envs/wrappers/training.py
return jax.tree_util.tree_map(f, tree)

def train(
Forking this file is going to make maintenance harder in the long run. Is there a minimal set of changes you can merge back into the main train.py file? I'm curious why you decided to do a hard fork, considering that most of the file looks almost the same as train.py.
I was thinking that we might refine the pixel specific stuff more later on. But by that point, we might as well make a new folder. I’ve merged this into the main train.py.
I've just rebased this PR on top of the dict obs PR and done the high-level refactoring just mentioned. It'd be great if you could take a look!
brax/training/types.py
Outdated
@@ -30,7 +30,8 @@
 Params = Any
 PRNGKey = jnp.ndarray
 Metrics = Mapping[str, jnp.ndarray]
-Observation = jnp.ndarray
+Observation = Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
can you please rebase to clean up the diff?
Let's maybe call these networks_vision.py, networks_cnn.py, networks_*.py, etc.
Warning: this expects the images to be 3D; convention NHWC
num_filters: the number of filters per layer
kernel_sizes: also per layer
"""
nit: please consider running black on your PR
great start!
brax/envs/fast.py
Outdated
@@ -30,7 +31,9 @@ def __init__(self, **kwargs):
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)
self._asymmetric_obs = kwargs.get('asymmetric_obs', False)
if self._asymmetric_obs and not self._use_dict_obs:
self._pixel_obs = kwargs.get('pixel_obs', False)
self._pixels_only = kwargs.get('pixels_only', False)
this made me scratch my head for a moment - can we instead have self._state_obs and self._pixel_obs, and then raise an error if both are false?
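A minimal sketch of the suggested flag handling; the class and attribute names mirror the comment, not the final brax code:

```python
class FastEnv:
    """Illustrative stand-in for the fast test env; only the flag logic is shown."""

    def __init__(self, state_obs: bool = True, pixel_obs: bool = False):
        # Fail early if neither observation type is enabled, as suggested.
        if not (state_obs or pixel_obs):
            raise ValueError('At least one of state_obs or pixel_obs must be True.')
        self._state_obs = state_obs
        self._pixel_obs = pixel_obs
```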
brax/envs/fast.py
Outdated
@@ -63,7 +72,13 @@ def step(self, state: State, action: jax.Array) -> State:
obs = {'state': obs} if self._use_dict_obs else obs
if self._asymmetric_obs:
  obs['privileged_state'] = jp.zeros(4)  # Dummy privileged state.
if self._pixel_obs:
  pixels = dict(
dumb code is smart. i'd suggest dropping the "for i in range(2)" bit and just spell it out (as you did in the shape specification below)
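The suggestion, sketched with made-up key names and shapes - two explicit entries instead of a dict built in a loop:

```python
import numpy as np  # the diff uses jp.zeros; numpy keeps this sketch dependency-free

# Instead of something like:
#   pixels = {f'pixels/view_{i}': np.zeros((4, 4, 3)) for i in range(2)}
# just spell it out:
pixels = {
    'pixels/view_0': np.zeros((4, 4, 3)),  # dummy pixel observation, camera 0
    'pixels/view_1': np.zeros((4, 4, 3)),  # dummy pixel observation, camera 1
}
```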
brax/envs/fast.py
Outdated
@@ -30,7 +31,9 @@ def __init__(self, **kwargs):
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)
OOC why are we using kwargs.get instead of just adding the param to the function signature?
def __init__(self, use_dict_obs: bool = False, asymmetric_obs: bool = False, ..., **kwargs):
better for type checking
feel free to ignore if this breaks something somewhere else
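A sketch of the suggested signature change; the class body is illustrative, not the actual fast env:

```python
class FastEnv:
    def __init__(
        self,
        use_dict_obs: bool = False,
        asymmetric_obs: bool = False,
        **kwargs,
    ):
        # Explicit keyword parameters are visible to type checkers and IDE
        # completion, unlike values fished out of **kwargs at runtime.
        self._use_dict_obs = use_dict_obs
        self._asymmetric_obs = asymmetric_obs
```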
brax/envs/wrappers/training.py
Outdated
-def __init__(self, env: Env, episode_length: int, action_repeat: int):
+def __init__(
+    self, env: Env, episode_length: int,
+    action_repeat: int, scan: bool = False):
i'm confused by this scan param, what's it for? and why does it default to false, this seems to break behavior for people who use EpisodeWrapper? and why can i specify action_repeat > 1 but scan = False, shouldn't that be an exception?
@@ -124,10 +124,13 @@ def update(state: RunningStatisticsState,
 # We require exactly the same structure to avoid issues when flattened
 # batch and state have different order of elements.
 assert jax.tree_util.tree_structure(batch) == jax.tree_util.tree_structure(state.mean)
-batch_shape = jax.tree_util.tree_leaves(batch)[0].shape
+batch_leaves = jax.tree_util.tree_leaves(batch)
+batch_shape = batch_leaves[0].shape if batch_leaves else ()
is this to handle the scenario that there's nothing to normalize, e.g. if your env only returns pixels and nothing else?
it may be clearer to handle this explicitly instead of allowing the code to execute normally, e.g.:
batch_leaves = ...
if not batch_leaves: # nothing to normalize
return Something # (i'll let you figure it out)
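One way to make the empty-pytree case explicit, sketched with jax; the helper name is made up, and the real update() would return a statistics state rather than a shape:

```python
import jax

def batch_shape_or_none(batch):
    """Return the leading leaf shape, or None when there is nothing to normalize."""
    batch_leaves = jax.tree_util.tree_leaves(batch)
    if not batch_leaves:  # nothing to normalize, e.g. a pixels-only env
        return None
    return batch_leaves[0].shape
```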
activation=linen.relu,
use_bias=False,
)
cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden.keys()]
nit: you don't need keys(); iterating over a dict yields its keys by default
  key: ln_per_chan(v) for key, v in data.items() if key.startswith("pixels/")
}
else:
  pixels_hidden = {k: v for k, v in data.items() if k.startswith("pixels/")}
nit: define this at the beginning of the function, then for self.normalise_channels do:
pixels_hidden = jax.tree.map(ln_per_chan, pixels_hidden)
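The suggested restructuring, sketched with a stand-in ln_per_chan (the PR's actual normalisation may differ):

```python
import jax
import jax.numpy as jnp

def ln_per_chan(v):
    # Stand-in per-channel layer norm: normalise over the spatial dims
    # independently for each channel of an HWC (or NHWC) image.
    mean = v.mean(axis=(-3, -2), keepdims=True)
    var = v.var(axis=(-3, -2), keepdims=True)
    return (v - mean) * jax.lax.rsqrt(var + 1e-6)

def pixel_branch(data, normalise_channels: bool):
    # Define the pixel sub-dict once at the top, then normalise it with a
    # single tree map when requested, as the review suggests.
    pixels_hidden = {k: v for k, v in data.items() if k.startswith('pixels/')}
    if normalise_channels:
        pixels_hidden = jax.tree_util.tree_map(ln_per_chan, pixels_hidden)
    return pixels_hidden
```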
FrozenDicts are used to avoid incorrect gradients."""
if not isinstance(obs, Mapping):
  return obs
obs = FrozenDict(obs)
this seems complicated - why not do as above:
return {k: v for k, v in obs.items() if not k.startswith("pixels/")}
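The suggested simplification, sketched as a standalone helper (whether FrozenDict is still needed elsewhere is a separate question):

```python
from collections.abc import Mapping

def remove_pixels(obs):
    """Drop pixel entries from a dict observation; pass arrays through unchanged."""
    if not isinstance(obs, Mapping):
        return obs
    return {k: v for k, v in obs.items() if not k.startswith('pixels/')}
```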
def apply(processor_params, policy_params, obs):
  obs = FrozenDict(obs)
  if state_obs_key:
    state_obs = preprocess_observations_fn(remove_pixels(obs), processor_params)
i'm not quite sure i follow this. it seems like you are doing this:
- preprocess all non-pixel fields (i assume this is for normalization)
- take only a SINGLE field (the state obs key) from the result and pass it into the module
if that's the case, why not just preprocess the state obs key?
brax/training/agents/ppo/train.py
Outdated
@@ -218,14 +226,16 @@ def train(
     randomization_fn, rng=randomization_rng
 )
 if isinstance(environment, envs.Env):
-  wrap_for_training = envs.training.wrap
+  wrap_for_training = functools.partial(
ah i see
did you try instead changing the order of the wrapping in wrap_for_training such that VmapWrapper is created first? Then you might not need this scan param
!
brax/envs/fast.py
Outdated
use_dict_obs: bool = False,
asymmetric_obs: bool = False,
pixel_obs: bool = False,
state_obs: bool = True,
these bools are getting a bit gnarly, can (use_dict_obs, pixel_obs and state_obs) be a single enum? From the user perspective, we could pass a string, and convert to enum as input validation within the class
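A sketch of the string-to-enum validation pattern being suggested; the member names here are illustrative (only NDARRAY appears in the later diff):

```python
import enum

class ObservationMode(enum.Enum):
    # Hypothetical members collapsing the use_dict_obs/pixel_obs/state_obs
    # booleans; the merged PR's actual members may differ.
    NDARRAY = 'ndarray'
    DICT_STATE = 'dict_state'
    DICT_PIXELS = 'dict_pixels'
    DICT_PIXELS_STATE = 'dict_pixels_state'

def parse_obs_mode(mode: str) -> ObservationMode:
    # Users pass a plain string; conversion doubles as input validation.
    try:
        return ObservationMode(mode)
    except ValueError:
        raise ValueError(f'unknown observation mode: {mode!r}')
```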
Latest changes look great! After you address @btaba 's last comment I'd say we are good to go.
brax/envs/fast.py
Outdated
ret = super().observation_size
if self._obs_mode == ObservationMode.NDARRAY:
  return ret
else:
nit: you don't need the else here, since you return early above
brax/training/agents/ppo/train.py
Outdated
@@ -164,6 +175,11 @@ def train(
 Returns:
   Tuple of (make_policy function, network params, metrics)
 """
+if madrona_backend:
+  assert not eval_env, "Madrona-MJX doesn't support multiple env instances"
nit:
if cond:
raise AssertionError
or maybe even raise ValueError(....
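The nit, spelled out as a sketch; the function name is made up, the message comes from the diff:

```python
def check_madrona_eval_env(madrona_backend: bool, eval_env) -> None:
    # An explicit exception survives `python -O`, which strips asserts.
    if madrona_backend and eval_env is not None:
        raise ValueError("Madrona-MJX doesn't support multiple env instances")
```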
brax/training/agents/ppo/train.py
Outdated
@@ -267,6 +283,36 @@ def train(
 gradient_update_fn = gradients.gradient_update_fn(
     loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)

+def random_translate_pixels(obs: Mapping[str, jax.Array], key):
nit: type annotation on key
also, does this function need to live inside train()
or can it be moved outside?
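For context, a minimal standalone sketch of a DrQ-style random shift for a single HWC image (not the PR's actual random_translate_pixels):

```python
import jax
import jax.numpy as jnp

def random_shift(key, img, pad: int = 4):
    """Pad spatially with edge replication, then crop a random window back
    to the original size - the basic DrQ augmentation."""
    h, w, c = img.shape
    padded = jnp.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode='edge')
    ky, kx = jax.random.split(key)
    y0 = jax.random.randint(ky, (), 0, 2 * pad + 1)
    x0 = jax.random.randint(kx, (), 0, 2 * pad + 1)
    # Slice sizes must be static; start indices may be traced values.
    return jax.lax.dynamic_slice(padded, (y0, x0, 0), (h, w, c))
```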
It might be worth moving this to an image_augmentations.py so we can also add more augmentations in the future to improve vision RL (color shifts etc.)
Let's do that in the future after accumulating a few more useful transformations :)
if randomization_fn is None:
  env = VmapWrapper(env)
else:
  env = DomainRandomizationVmapWrapper(env, randomization_fn)
env = EpisodeWrapper(env, episode_length, action_repeat)
just OOC why is this move needed?
Something about how we implemented the batching rules that we use to call Madrona means that vmap(scan(render())) works but scan(vmap(render())) fails. When we have some time we will fix this in madrona_mjx, but for now it's easier to fix it here.
The Madrona Batch Renderer is sensitive to the scan inside the EpisodeWrapper. It breaks when you do vmap(scan(env.step)) but is ok for scan(vmap(env.step)).
This change might not be needed anymore, since Madrona requires its own vmap wrapper to be used: https://github.com/shacklettbp/madrona_mjx/blob/main/src/madrona_mjx/wrapper.py
wrap_env must be False when using vision, so that users know to use the MadronaWrapper, and we can include instructions over there for how to correctly order the wraps of the environment.
Nice. I would be OK with merging this as is and then adopting the Madrona wrapper as a separate PR.
Updated per feedback!
@Andrew-Luo1 can you follow the process to resolve conflicts, and then we will merge?
Conflicts resolved!
Let's try again? The SAC test was failing because it previously expected 2 resets, not counting the one from the env.observation_size call, since the previous fast env returned a hard-coded observation size. However, the fast env now calls super().observation_size, which adds an additional reset.
Results
The PPO + networks setup introduced in this PR stably trains a pixel + proprioceptive pick-up-cube task within 20 minutes, and the pixel-based cartpole within a minute.
2024_11_26-visual_policy.mp4
original_cartpole.mp4
Algorithm
Usage
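A sketch of the observation layout this PR expects; the camera name and shapes are made up for illustration:

```python
import numpy as np

obs = {
    'pixels/cam0': np.zeros((64, 64, 3), dtype=np.float32),  # one camera view, HWC
    'state': np.zeros(10, dtype=np.float32),  # optional proprioceptive state
}
```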
Note that the proposed changes only support dictionary-valued observations. The pixel observations are expected to be under obs['pixels/<name>'], whereas any additional state-based observations are to be under obs['state'].

Image augmentation
The key algorithmic addition is the option to randomly shift pixel data, toggled by the augment_image flag in def train. This implementation is a PPO adaptation of DrQ in the basic case of K=M=1 (paper section 3.3). This simple change yields a significant improvement in the above Franka pick-cube task, while not significantly affecting performance in cartpole balance.

Networks
This PR provides a VisionMLP - a basic network for pixel-based policies that combines a CNN encoder (the minimal NatureCNN architecture) with MLP hidden layers afterwards. This network is based on work by @StafaH and has three main features, including an option (normalise_channels) to independently apply layer norm to each pixel channel. This has been found to be useful when RGB and depth inputs are stacked for each camera view, since even when channels are normalised to [0, 1], depth inputs can be consistently larger in the case that most of the scene is background.

Wrapping
For pixel-based observations, the PR enforces that action_repeat = 1.
The [EpisodeWrapper](https://github.com/google/brax/blob/main/brax/envs/wrappers/training.py#L93) scans over the entire env.step action_repeat times, generating action_repeat - 1 unnecessary observations. In the case of pixel-based training, the additional overhead of unnecessary render calls becomes significant.

Note that Brax offers a mostly equivalent mechanism for having each unroll see more simulated time without generating more RL transitions: the under-the-hood scan of mjx.step.
If we call EpisodeWrapper's action repeat the upper action repeat (UAR), and the mjx.step scan the lower action repeat (LAR), we can simply discard the UAR and set the LAR to $LAR \cdot UAR$ in the unwrapped environment.

Let $l_e$ and $l_e'$ be the episode lengths for EpisodeWrapper and VisionEpisodeWrapper, and $T$ and $T'$ similarly for the total training steps. To ensure the same number of RL transitions and simulated time, set $l_e' = l_e / UAR$ and $T' = T / UAR$.
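A worked instance of the retuning rule with made-up numbers:

```python
UAR = 4                    # upper action repeat being discarded
episode_length = 1000      # l_e under EpisodeWrapper
total_steps = 20_000_000   # T

# Keep simulated time and the RL transition count fixed:
episode_length_new = episode_length // UAR  # l_e' = l_e / UAR
total_steps_new = total_steps // UAR        # T'  = T / UAR
```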
In practice, wrapping in this way yields a significant performance improvement. On the above Franka cube task, the raw sim throughput from looping jit_step goes from 99.8k to 173.9k simulation steps/second. The wall clock for the equivalent amount of simulated time and RL transitions goes from 25m:40s to 19m:28s.