
PPO on Pixels #560

Merged
merged 31 commits into main on Dec 4, 2024

Conversation

@Andrew-Luo1 (Contributor) commented Nov 26, 2024

Results
The PPO + networks setup introduced in this PR stably trains a pixel + proprioceptive pick-up-cube task within 20 minutes and the pixel-based cartpole within a minute.

2024_11_26-visual_policy.mp4
original_cartpole.mp4

Algorithm

Usage
Note that the proposed changes only support dictionary-valued observations: pixel observations are expected under obs['pixels/<name>'], and any additional state-based observations under obs['state'].
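
For illustration, a hedged example of the expected observation layout (the camera names and shapes here are hypothetical, not mandated by the PR):

    import jax.numpy as jp

    obs = {
        'pixels/front': jp.zeros((64, 64, 3)),   # RGB from a hypothetical front camera
        'pixels/wrist': jp.zeros((64, 64, 4)),   # stacked RGB-D from a hypothetical wrist camera
        'state': jp.zeros((27,)),                # proprioceptive / state vector
    }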

Image augmentation
The key algorithmic addition is the option to randomly shift the pixel observations, toggled by the augment_image flag in train(). This implementation is a PPO adaptation of DrQ in the basic case of K = M = 1 (paper section 3.3). This simple change yields a significant improvement on the Franka pick-cube task above, while not significantly affecting performance on cartpole balance.
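
Below is a minimal sketch of a DrQ-style random shift under the observation convention above; the function names and the pad width of 4 pixels are illustrative assumptions, not the PR's exact implementation.

    import jax
    import jax.numpy as jp


    def random_translate(img: jax.Array, key: jax.Array, pad: int = 4) -> jax.Array:
      """Pads an (H, W, C) image with edge replication and crops a random window
      back to the original size, i.e. shifts it by up to +/- `pad` pixels."""
      h, w, c = img.shape
      padded = jp.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode='edge')
      dy, dx = jax.random.randint(key, (2,), 0, 2 * pad + 1)
      return jax.lax.dynamic_slice(padded, (dy, dx, 0), (h, w, c))


    def augment_pixels(obs: dict, key: jax.Array) -> dict:
      """Applies an independent random shift to every 'pixels/...' entry."""
      out = dict(obs)
      for name, value in obs.items():
        if name.startswith('pixels/'):
          key, sub = jax.random.split(key)
          out[name] = random_translate(value, sub)
      return out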

[Image attachment: Pasted image 20241122080755]

Networks

This PR provides VisionMLP, a basic network for pixel-based policies that combines a CNN encoder (the minimal NatureCNN architecture) with MLP hidden layers on top. The network is based on work by @StafaH and has three main features (a hedged sketch follows the list):

  1. Support for multiple pixel streams, for example assigning a separate CNN to each camera view.
  2. The option (normalise_channels) to apply layer norm independently to each pixel channel. This has been found useful when RGB and depth inputs are stacked for each camera view: even when channels are normalised to [0, 1], the depth channel can be consistently larger when most of the scene is background.
  3. Appending any state-based observations to the CNN output encodings.
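
Below is a minimal sketch of such a network, assuming flax.linen and the observation convention above. The filter spec, layer sizes, and names are illustrative assumptions rather than the exact definitions added by this PR, and the per-channel normalisation shown is a parameter-free stand-in for the normalise_channels layer norm.

    from typing import Mapping, Sequence

    import jax
    import jax.numpy as jp
    from flax import linen


    def _layer_norm_per_channel(x: jax.Array, eps: float = 1e-6) -> jax.Array:
      """Normalises each channel independently over its spatial dims (H, W)."""
      mean = x.mean(axis=(-3, -2), keepdims=True)
      var = x.var(axis=(-3, -2), keepdims=True)
      return (x - mean) / jp.sqrt(var + eps)


    class NatureCNN(linen.Module):
      """Minimal NatureCNN encoder: three convolutions, then flatten."""

      @linen.compact
      def __call__(self, x: jax.Array) -> jax.Array:
        for features, kernel, stride in ((32, 8, 4), (64, 4, 2), (64, 3, 1)):
          x = linen.relu(linen.Conv(features, (kernel, kernel), (stride, stride))(x))
        return x.reshape((*x.shape[:-3], -1))  # flatten H, W, C


    class VisionMLP(linen.Module):
      """One CNN per pixel stream, optional per-channel normalisation, and an
      MLP over the concatenated encodings plus any state observations."""
      hidden_sizes: Sequence[int] = (256, 256)
      output_size: int = 8
      normalise_channels: bool = False

      @linen.compact
      def __call__(self, obs: Mapping[str, jax.Array]) -> jax.Array:
        pixels = {k: v for k, v in obs.items() if k.startswith('pixels/')}
        if self.normalise_channels:
          pixels = {k: _layer_norm_per_channel(v) for k, v in pixels.items()}
        # Feature 1: a separate CNN encoder per pixel stream (e.g. per camera).
        encodings = [NatureCNN()(pixels[k]) for k in sorted(pixels)]
        # Feature 3: append state-based observations to the CNN encodings.
        if 'state' in obs:
          encodings.append(obs['state'])
        hidden = jp.concatenate(encodings, axis=-1)
        for size in self.hidden_sizes:
          hidden = linen.relu(linen.Dense(size)(hidden))
        return linen.Dense(self.output_size)(hidden)

A separate encoder per stream keeps each camera's features independent; stacking all views along the channel axis and sharing one encoder is a common alternative.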

Wrapping

For pixel-based observations, the PR enforces that action_repeat = 1.

The [EpisodeWrapper](https://github.com/google/brax/blob/main/brax/envs/wrappers/training.py#L93) scans over env.step action_repeat times, generating action_repeat - 1 unnecessary observations:

    def f(state, _):
      nstate = self.env.step(state, action)
      return nstate, nstate.reward

    state, rewards = jax.lax.scan(f, state, (), self.action_repeat)
    state = state.replace(reward=jp.sum(rewards, axis=0))

In the case of pixel-based training, the additional overhead of unnecessary render calls becomes significant.

Note that Brax offers a mostly equivalent mechanism for having each unroll cover more simulated time without generating more RL transitions: the under-the-hood scan of mjx.step.

If we call EpisodeWrapper's action repeat the upper action repeat (UAR) and the mjx.step scan the lower action repeat (LAR), we can simply discard the UAR and set the LAR to $LAR \times UAR$ in the unwrapped environment.

Let $l_e$ and $l_e'$ be the episode lengths for EpisodeWrapper and VisionEpisodeWrapper, and let $T$ and $T'$ be the corresponding total training steps. To keep the number of RL transitions and the simulated time the same, set $l_e' = l_e / UAR$ and $T' = T / UAR$.
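
As a worked example with illustrative numbers: suppose $UAR = 4$ and the original configuration uses $l_e = 1000$ and $T = 60\text{M}$. The vision setup would multiply the unwrapped environment's LAR by 4 and use $l_e' = 1000 / 4 = 250$ and $T' = 60\text{M} / 4 = 15\text{M}$, so each RL transition covers the same simulated time and the total number of transitions is unchanged.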

In practice, wrapping this way yields a significant performance improvement. On the Franka cube task above, the raw sim throughput from looping jit_step goes from 99.8k to 173.9k simulation steps/second, and the wall clock for the equivalent amount of simulated time and RL transitions drops from 25m:40s to 19m:28s.
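
For concreteness, here is a hedged sketch of what a scan-free episode wrapper can look like, in the spirit of the VisionEpisodeWrapper mentioned above; the class name and bookkeeping are illustrative and the PR's actual wrapper may differ.

    import jax
    import jax.numpy as jp
    from brax.envs.base import Env, State, Wrapper


    class ScanFreeEpisodeWrapper(Wrapper):
      """Tracks episode termination without an action_repeat scan, so env.step
      (and hence the renderer) runs exactly once per RL transition."""

      def __init__(self, env: Env, episode_length: int):
        super().__init__(env)
        self.episode_length = episode_length

      def reset(self, rng: jax.Array) -> State:
        state = self.env.reset(rng)
        state.info['steps'] = jp.zeros(rng.shape[:-1])
        return state

      def step(self, state: State, action: jax.Array) -> State:
        state = self.env.step(state, action)  # one step, one render
        steps = state.info['steps'] + 1
        done = jp.where(steps >= self.episode_length, jp.ones_like(state.done), state.done)
        state.info['steps'] = steps
        return state.replace(done=done)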


google-cla bot commented Nov 26, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@btaba (Collaborator) commented Nov 26, 2024

Thanks Andrew for the PR, will review it more thoroughly once the upstream PR is submitted.

One initial high level thought is that the wrapper should go in https://github.com/google/brax/blob/main/brax/envs/wrappers/training.py

return jax.tree_util.tree_map(f, tree)


def train(

@btaba (Collaborator) commented Nov 26, 2024

Forking this file is going to make maintenance harder in the long run. Are there a minimal set of changes you can merge back into the main train.py file? I'm curious why you decided to do a hard fork considering that most of the file looks almost the same as train.py?

@Andrew-Luo1 (Contributor Author) replied

I was thinking that we might refine the pixel-specific stuff more later on, but by that point we might as well make a new folder. I've merged this into the main train.py.

@btaba btaba self-assigned this Nov 26, 2024
@Andrew-Luo1 Andrew-Luo1 force-pushed the vision_ppo_rebased branch 2 times, most recently from df16198 to 7e5ef39 Compare November 27, 2024 23:10

@Andrew-Luo1 (Contributor Author):

I've just rebased this PR on top of the dict obs PR and done the high-level refactoring just mentioned. It'd be great if you could take a look!

@@ -30,7 +30,8 @@
Params = Any
PRNGKey = jnp.ndarray
Metrics = Mapping[str, jnp.ndarray]
Observation = jnp.ndarray
Observation = Union[jnp.ndarray, Mapping[str, jnp.ndarray]]

Collaborator:
can you please rebase to clean up the diff?

Collaborator:
Let's maybe call these networks_vision.py, networks_cnn.py, networks_*.py etc.

Warning: this expects the images to be 3D; convention NHWC
num_filters: the number of filters per layer
kernel_sizes: also per layer
"""

Collaborator:
nit: please consider running black on your PR

@erikfrey (Collaborator) left a comment:
great start!

@@ -30,7 +31,9 @@ def __init__(self, **kwargs):
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)
self._asymmetric_obs = kwargs.get('asymmetric_obs', False)
if self._asymmetric_obs and not self._use_dict_obs:
self._pixel_obs = kwargs.get('pixel_obs', False)
self._pixels_only = kwargs.get('pixels_only', False)

Collaborator:
this made me scratch my head for a moment - can we instead have:

self._state_obs
self._pixel_obs

and then raise an error if both are false?

@@ -63,7 +72,13 @@ def step(self, state: State, action: jax.Array) -> State:
obs = {'state': obs} if self._use_dict_obs else obs
if self._asymmetric_obs:
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state.
if self._pixel_obs:
pixels = dict(

Collaborator:
dumb code is smart. i'd suggest dropping the "for i in range(2)" bit and just spell it out (as you did in the shape specification below)

@@ -30,7 +31,9 @@ def __init__(self, **kwargs):
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)

Collaborator:
OOC why are we using kwargs.get instead of just adding the param to the function signature?

def __init__(self, use_dict_obs: bool = False, asymmetric_obs: bool = False, ..., **kwargs):

better for type checking

feel free to ignore if this breaks something somewhere else

def __init__(self, env: Env, episode_length: int, action_repeat: int):
def __init__(
self, env: Env, episode_length: int,
action_repeat: int, scan: bool = False):

Collaborator:
i'm confused by this scan param, what's it for? and why does it default to false, this seems to break behavior for people who use EpisodeWrapper? and why can i specify action_repeat > 1 but scan = False, shouldn't that be an exception?

@@ -124,10 +124,13 @@ def update(state: RunningStatisticsState,
# We require exactly the same structure to avoid issues when flattened
# batch and state have different order of elements.
assert jax.tree_util.tree_structure(batch) == jax.tree_util.tree_structure(state.mean)
batch_shape = jax.tree_util.tree_leaves(batch)[0].shape
batch_leaves = jax.tree_util.tree_leaves(batch)
batch_shape = batch_leaves[0].shape if batch_leaves else ()

Collaborator:
is this to handle the scenario that there's nothing to normalize, e.g. if your env only returns pixels and nothing else?

it may be clearer to handle this explicitly instead of allowing the code to execute normally, e.g.:

batch_leaves = ...
if not batch_leaves:  # nothing to normalize
  return Something # (i'll let you figure it out)

activation=linen.relu,
use_bias=False,
)
cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden.keys()]

Collaborator:
nit: don't need keys(), that's the default for iterators

key: ln_per_chan(v) for key, v in data.items() if key.startswith("pixels/")
}
else:
pixels_hidden = {k: v for k, v in data.items() if k.startswith("pixels/")}

Collaborator:
nit: define this at the beginning of the function, then for self.normalise_channels do:

pixels_hidden = jax.tree.map(ln_per_chan, pixels_hidden)

FrozenDicts are used to avoid incorrect gradients."""
if not isinstance(obs, Mapping):
return obs
obs = FrozenDict(obs)

Collaborator:
this seems complicated - why not do as above:

return {k: v for k, v in obs.items() if not k.startswith("pixels/")}

def apply(processor_params, policy_params, obs):
obs = FrozenDict(obs)
if state_obs_key:
state_obs = preprocess_observations_fn(remove_pixels(obs), processor_params)

Collaborator:
i'm not quite sure i follow this. it seems like you are doing this:

  1. preprocess all non-pixel fields (i assume this is for normalization)
  2. take only a SINGLE field (the state obs key) from the result and pass it into the module

if that's the case, why not just preprocess the state obs key?

@@ -218,14 +226,16 @@ def train(
randomization_fn, rng=randomization_rng
)
if isinstance(environment, envs.Env):
wrap_for_training = envs.training.wrap
wrap_for_training = functools.partial(

Collaborator:
ah i see

did you try instead changing the order of the wrapping in wrap_for_training such that VmapWrapper is created first? Then you might not need this scan param

@erikfrey (Collaborator) left a comment:
!

use_dict_obs: bool = False,
asymmetric_obs: bool = False,
pixel_obs: bool = False,
state_obs: bool = True,

Collaborator:
these bools are getting a bit gnarly, can (use_dict_obs, pixel_obs and state_obs) be a single enum? From the user perspective, we could pass a string, and convert to enum as input validation within the class

@erikfrey (Collaborator) left a comment:
Latest changes look great! After you address @btaba 's last comment I'd say we are good to go.

ret = super().observation_size
if self._obs_mode == ObservationMode.NDARRAY:
return ret
else:

Collaborator:
nit: you don't need the else here, since you return early above

@@ -164,6 +175,11 @@ def train(
Returns:
Tuple of (make_policy function, network params, metrics)
"""
if madrona_backend:
assert not eval_env, "Madrona-MJX doesn't support multiple env instances"

Collaborator:
nit:

if cond:
  raise AssertionError

Collaborator:
or maybe even raise ValueError(....

@@ -267,6 +283,36 @@ def train(
gradient_update_fn = gradients.gradient_update_fn(
loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)

def random_translate_pixels(obs: Mapping[str, jax.Array], key):

Collaborator:
nit: type annotation on key

also, does this function need to live inside train() or can it be moved outside?

@StafaH (Contributor) commented Dec 3, 2024:
It might be worth moving to an image_augmentations.py so we can also add more augmentations in the future to improve vision RL (color shifts etc.)

@Andrew-Luo1 (Contributor Author) replied:
Let's do that in the future after accumulating a few more useful transformations :)

if randomization_fn is None:
env = VmapWrapper(env)
else:
env = DomainRandomizationVmapWrapper(env, randomization_fn)
env = EpisodeWrapper(env, episode_length, action_repeat)

Collaborator:
just OOC why is this move needed?

Collaborator:
Something about how we implemented the batching rules that we use to call Madrona means that vmap(scan(render())) works but scan(vmap(render())) fails.

When we have some time we will fix this in madrona_mjx but for now it's easier to fix it here.

@Andrew-Luo1 (Contributor Author) commented Dec 3, 2024:
The Madrona Batch Renderer is sensitive to the scan inside the EpisodeWrapper. It breaks when you do vmap(scan(env.step)) but is ok for scan(vmap(env.step)).

@StafaH (Contributor) commented Dec 3, 2024:
This change might not be needed anymore, since Madrona requires its own vmap wrapper to be used:
https://github.com/shacklettbp/madrona_mjx/blob/main/src/madrona_mjx/wrapper.py

wrap_env must be False when using vision, so that users know to use the MadronaWrapper, and we can include instructions over there on how to correctly order the environment wrappers.

Collaborator:
Nice. I would be OK with merging this as is and then adopting the Madrona wrapper as a separate PR.

@Andrew-Luo1 (Contributor Author):

Updated per the feedback!

@erikfrey (Collaborator) commented Dec 4, 2024

@Andrew-Luo1 can you follow the process to resolve conflicts, and then we will merge?

@Andrew-Luo1 (Contributor Author):

Conflicts resolved!

@Andrew-Luo1 (Contributor Author) commented Dec 4, 2024

Let's try again? The SAC test was failing because it previously expected 2 resets, not counting the one from the env.observation_size call, since the previous fast env returned a hard-coded observation size. However, the fast env now calls super().observation_size, which adds an additional reset.

@btaba btaba merged commit 68906bc into google:main Dec 4, 2024
3 checks passed