PPO on Pixels #560

Merged: 31 commits, Dec 4, 2024.
Changes shown below are from 19 of the 31 commits.

Commits:
acc4ed7  Initial commit of vision PPO networks (StafaH, Nov 22, 2024)
6aafe66  implement vision wrappers (Andrew-Luo1, Nov 24, 2024)
9151e0c  change ppo loss and the autoresetwrapper to support dictionary-valued… (Andrew-Luo1, Nov 24, 2024)
f97358c  add random image shifts (Andrew-Luo1, Nov 24, 2024)
a99d4d7  support normalising observations, clean up train_pixels.py (Andrew-Luo1, Nov 25, 2024)
131c325  vision networks (Andrew-Luo1, Nov 25, 2024)
8d6d2c4  fix bug in state normalisation (Andrew-Luo1, Nov 26, 2024)
583f7d5  add channel-wise layer norm in CNN (Andrew-Luo1, Nov 26, 2024)
b11a7dc  remove old file (Andrew-Luo1, Nov 26, 2024)
f848837  clean up imports (Andrew-Luo1, Nov 26, 2024)
e899b6d  enforce FrozenDict to avoid incorrect gradients (Andrew-Luo1, Nov 27, 2024)
e597fb6  refactor the vision wrappers as flags in envs.training wrappers (Andrew-Luo1, Nov 27, 2024)
67db2a5  support asymmetric actor critic on pixels, clean up normalisation logic (Andrew-Luo1, Dec 2, 2024)
f68bc0e  rename networks files (Andrew-Luo1, Dec 2, 2024)
30c3087  write basic pixels ppo test, make remove_pixels() check for non-dict obs (Andrew-Luo1, Dec 2, 2024)
09377c4  update test for ppo on pixels to test pixel-only observations and cas… (Andrew-Luo1, Dec 2, 2024)
28288c6  fix bug for aac on pixels (Andrew-Luo1, Dec 2, 2024)
9438bd7  remove old file (Andrew-Luo1, Dec 2, 2024)
27225b2  linting (Andrew-Luo1, Dec 2, 2024)
982e142  clean up logic for toy testing env (Andrew-Luo1, Dec 3, 2024)
24730bf  small code placement and logic clean-up (Andrew-Luo1, Dec 3, 2024)
051c4fa  for vision networks, only normalize as needed (Andrew-Luo1, Dec 3, 2024)
769752b  move vision networks around (Andrew-Luo1, Dec 3, 2024)
71cd073  remove scan parameter for wrapping but switch wrapping order (Andrew-Luo1, Dec 3, 2024)
fac85fb  linting (Andrew-Luo1, Dec 3, 2024)
40c3985  add acknowledgement (Andrew-Luo1, Dec 3, 2024)
7feed21  replace boolean args to testing env with obs_mode enum (Andrew-Luo1, Dec 3, 2024)
be02919  write docstring for toy testing env and clean up (Andrew-Luo1, Dec 3, 2024)
cd7aa15  make pixels functions private (Andrew-Luo1, Dec 4, 2024)
692d626  Merge branch 'main' into vision_ppo_rebased (Andrew-Luo1, Dec 4, 2024)
19f6337  update sac test (Andrew-Luo1, Dec 4, 2024)
20 changes: 19 additions & 1 deletion brax/envs/fast.py
@@ -19,6 +19,7 @@
from brax.envs.base import PipelineEnv, State
import jax
from jax import numpy as jp
from flax.core import FrozenDict


class Fast(PipelineEnv):
@@ -30,7 +31,9 @@ def __init__(self, **kwargs):
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)

Review comment (Collaborator):

OOC why are we using kwargs.get instead of just adding the param to the function signature?

def __init__(self, use_dict_obs: bool = False, asymmetric_obs: bool = False, ..., **kwargs):

better for type checking

feel free to ignore if this breaks something somewhere else

self._asymmetric_obs = kwargs.get('asymmetric_obs', False)
- if self._asymmetric_obs and not self._use_dict_obs:
+ self._pixel_obs = kwargs.get('pixel_obs', False)
+ self._pixels_only = kwargs.get('pixels_only', False)

Review comment (Collaborator):

this made me scratch my head for a moment - can we instead have:

self._state_obs
self._pixel_obs

and then raise an error if both are false?

+ if (self._asymmetric_obs or self._pixel_obs) and not self._use_dict_obs:
raise ValueError('asymmetric_obs requires use_dict_obs=True')

def reset(self, rng: jax.Array) -> State:
@@ -47,6 +50,12 @@ def reset(self, rng: jax.Array) -> State:
obs = {'state': obs} if self._use_dict_obs else obs
if self._asymmetric_obs:
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state.
if self._pixel_obs:
pixels = dict(
{f'pixels/view_{i}': jp.zeros((4, 4, 3)) for i in range(2)}
)
obs = pixels if self._pixels_only else {**obs, **pixels}
obs = FrozenDict(obs) if self._use_dict_obs else obs
reward, done = jp.array(0.0), jp.array(0.0)
return State(pipeline_state, obs, reward, done)

@@ -63,7 +72,13 @@ def step(self, state: State, action: jax.Array) -> State:
obs = {'state': obs} if self._use_dict_obs else obs
if self._asymmetric_obs:
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state.
if self._pixel_obs:
pixels = dict(
{f'pixels/view_{i}': jp.zeros((4, 4, 3)) for i in range(2)}
)
obs = pixels if self._pixels_only else {**obs, **pixels}

Review comment (Collaborator):

dumb code is smart. i'd suggest dropping the "for i in range(2)" bit and just spelling it out (as you did in the shape specification below)

reward = pos[0]
obs = FrozenDict(obs) if self._use_dict_obs else obs
return state.replace(pipeline_state=qp, obs=obs, reward=reward)

@property
@@ -82,6 +97,9 @@ def observation_size(self):
obs = {'state': 2}
if self._asymmetric_obs:
obs['privileged_state'] = 4
if self._pixel_obs:
obs['pixels/view_0'] = (4, 4, 3)
obs['pixels/view_1'] = (4, 4, 3)
return obs

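For reference, a minimal sketch of how the new flags compose on the toy env; the flag combination is illustrative, with shapes taken from the diff above:

import jax
from brax.envs.fast import Fast

env = Fast(use_dict_obs=True, pixel_obs=True, asymmetric_obs=True)
state = env.reset(jax.random.PRNGKey(0))
# Expect a FrozenDict with 'state' (2,), 'privileged_state' (4,), and
# 'pixels/view_0' / 'pixels/view_1', each of shape (4, 4, 3).
print({k: v.shape for k, v in state.obs.items()})
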
16 changes: 12 additions & 4 deletions brax/envs/wrappers/training.py
@@ -31,6 +31,7 @@ def wrap(
randomization_fn: Optional[
Callable[[System], Tuple[System, System]]
] = None,
scan: bool = False
) -> Wrapper:
"""Common wrapper pattern for all training agents.

@@ -46,7 +47,7 @@
environment did not already have batch dimensions, it is additional Vmap
wrapped.
"""
- env = EpisodeWrapper(env, episode_length, action_repeat)
+ env = EpisodeWrapper(env, episode_length, action_repeat, scan)
if randomization_fn is None:
env = VmapWrapper(env)
else:
@@ -74,10 +75,13 @@ def step(self, state: State, action: jax.Array) -> State:
class EpisodeWrapper(Wrapper):
"""Maintains episode step count and sets done at episode end."""

- def __init__(self, env: Env, episode_length: int, action_repeat: int):
+ def __init__(
+     self, env: Env, episode_length: int,
+     action_repeat: int, scan: bool = False):

Review comment (Collaborator):

i'm confused by this scan param, what's it for? and why does it default to false, this seems to break behavior for people who use EpisodeWrapper? and why can i specify action_repeat > 1 but scan = False, shouldn't that be an exception?

super().__init__(env)
self.episode_length = episode_length
self.action_repeat = action_repeat
self._scan = scan

def reset(self, rng: jax.Array) -> State:
state = self.env.reset(rng)
@@ -90,8 +94,12 @@ def f(state, _):
nstate = self.env.step(state, action)
return nstate, nstate.reward

- state, rewards = jax.lax.scan(f, state, (), self.action_repeat)
- state = state.replace(reward=jp.sum(rewards, axis=0))
+ if self._scan:
+   state, rewards = jax.lax.scan(f, state, (), self.action_repeat)
+   state = state.replace(reward=jp.sum(rewards, axis=0))
+ else:
+   state = self.env.step(state, action)

steps = state.info['steps'] + self.action_repeat
one = jp.ones_like(state.done)
zero = jp.zeros_like(state.done)
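
One way to make the scan question above explicit (a sketch, not part of this PR; the helper name is hypothetical):

def _check_episode_wrapper_args(action_repeat: int, scan: bool) -> None:
  # With scan=False the wrapped env steps only once per call, while
  # state.info['steps'] still advances by action_repeat, so silently
  # accepting action_repeat > 1 would under-step the environment.
  if action_repeat > 1 and not scan:
    raise ValueError('action_repeat > 1 requires scan=True in EpisodeWrapper.')
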
9 changes: 6 additions & 3 deletions brax/training/acme/running_statistics.py
@@ -124,10 +124,13 @@ def update(state: RunningStatisticsState,
# We require exactly the same structure to avoid issues when flattened
# batch and state have different order of elements.
assert jax.tree_util.tree_structure(batch) == jax.tree_util.tree_structure(state.mean)
- batch_shape = jax.tree_util.tree_leaves(batch)[0].shape
+ batch_leaves = jax.tree_util.tree_leaves(batch)
+ batch_shape = batch_leaves[0].shape if batch_leaves else ()

Review comment (Collaborator):

is this to handle the scenario that there's nothing to normalize, e.g. if your env only returns pixels and nothing else?

it may be clearer to handle this explicitly instead of allowing the code to execute normally, e.g.:

batch_leaves = ...
if not batch_leaves:  # nothing to normalize
  return Something # (i'll let you figure it out)

# We assume the batch dimensions always go first.
- batch_dims = batch_shape[:len(batch_shape) -
-     jax.tree_util.tree_leaves(state.mean)[0].ndim]
+ batch_dims = batch_shape[
+     : len(batch_shape)
+     - (jax.tree_util.tree_leaves(state.mean)[0].ndim if batch_leaves else 0)
+ ]
batch_axis = range(len(batch_dims))
if weights is None:
step_increment = jnp.prod(jnp.array(batch_dims))
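
One reading of the early-return suggestion above, assuming the intended no-op is to return the running statistics unchanged when the batch pytree has no leaves:

batch_leaves = jax.tree_util.tree_leaves(batch)
if not batch_leaves:  # nothing to normalize, e.g. pixel-only observations
  return state
batch_shape = batch_leaves[0].shape
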
96 changes: 96 additions & 0 deletions brax/training/agents/ppo/networks_cnn.py
@@ -0,0 +1,96 @@
"""
Network implementations
"""

from functools import partial
from typing import Any, Callable, Sequence, Tuple

from flax import linen
import jax
import jax.numpy as jp

from brax.training import networks

ModuleDef = Any
ActivationFn = Callable[[jp.ndarray], jp.ndarray]
Initializer = Callable[..., Any]


class CNN(linen.Module):

Review comment (Collaborator):

why is this specific to PPO?

could this go in brax/training/networks.py instead?

"""CNN module.
Warning: this expects each image to be 3D (HWC); with the batch dimension the layout is NHWC.
num_filters: the number of filters per layer
kernel_sizes: also per layer
"""

num_filters: Sequence[int]
kernel_sizes: Sequence[Tuple]
strides: Sequence[Tuple]
activation: ActivationFn = linen.relu
use_bias: bool = True

@linen.compact
def __call__(self, data: jp.ndarray):
hidden = data
for i, (num_filter, kernel_size, stride) in enumerate(
zip(self.num_filters, self.kernel_sizes, self.strides)
):
hidden = linen.Conv(
num_filter, kernel_size=kernel_size, strides=stride, use_bias=self.use_bias
)(hidden)

hidden = self.activation(hidden)
return hidden


class VisionMLP(linen.Module):
# Apply a CNN backbone then an MLP.

Review comment (Collaborator):

somewhere in here you should probably credit the paper explicitly, e.g.:

Architecture originates from "Human-level control through deep reinforcement learning.", Nature 518, no. 7540 (2015): 529-533.

layer_sizes: Sequence[int]
activation: ActivationFn = linen.relu
kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
activate_final: bool = False
layer_norm: bool = False
normalise_channels: bool = False
state_obs_key: str = ""

@linen.compact
def __call__(self, data: dict):
if self.normalise_channels:
# Calculates shared statistics over an entire 2D image.
image_layernorm = partial(
linen.LayerNorm, use_bias=False, use_scale=False, reduction_axes=(-1, -2)
)

def ln_per_chan(v: jax.Array):
normalised = [image_layernorm()(v[..., chan]) for chan in range(v.shape[-1])]
return jp.stack(normalised, axis=-1)

pixels_hidden = {
key: ln_per_chan(v) for key, v in data.items() if key.startswith("pixels/")
}
else:
pixels_hidden = {k: v for k, v in data.items() if k.startswith("pixels/")}

Review comment (Collaborator):

nit: define this at the beginning of the function, then for self.normalise_channels do:

pixels_hidden = jax.tree.map(ln_per_chan, pixels_hidden)


natureCNN = partial(
CNN,
num_filters=[32, 64, 64],
kernel_sizes=[(8, 8), (4, 4), (3, 3)],
strides=[(4, 4), (2, 2), (1, 1)],
activation=linen.relu,
use_bias=False,
)
cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden.keys()]

Review comment (Collaborator):

nit: don't need keys(), that's the default for iterators

cnn_outs = [jp.mean(cnn_out, axis=(-2, -3)) for cnn_out in cnn_outs]
if self.state_obs_key:
cnn_outs.append(
data[self.state_obs_key]
) # TODO: Try with dedicated state network

hidden = jp.concatenate(cnn_outs, axis=-1)
return networks.MLP(
layer_sizes=self.layer_sizes,
activation=self.activation,
kernel_init=self.kernel_init,
activate_final=self.activate_final,
layer_norm=self.layer_norm,
)(hidden)
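
For orientation, a minimal sketch of how VisionMLP consumes a mixed observation dict; the image sizes and the 'state' key are illustrative assumptions, while the 'pixels/' prefix follows this diff:

import jax
import jax.numpy as jp
from brax.training.agents.ppo.networks_cnn import VisionMLP

module = VisionMLP(layer_sizes=[256, 256], state_obs_key='state')
dummy_obs = {
    'pixels/view_0': jp.zeros((1, 64, 64, 3)),
    'pixels/view_1': jp.zeros((1, 64, 64, 3)),
    'state': jp.zeros((1, 8)),
}
params = module.init(jax.random.PRNGKey(0), dummy_obs)
# Per-view CNN features are mean-pooled over space, concatenated with
# 'state', then passed through the MLP; the output has shape (1, 256).
out = module.apply(params, dummy_obs)
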
149 changes: 149 additions & 0 deletions brax/training/agents/ppo/networks_vision.py
@@ -0,0 +1,149 @@
"""PPO vision networks."""

from typing import Any, Callable, Mapping, Sequence, Tuple, Union

import jax
import jax.numpy as jp
import flax
from flax import linen
from flax.core import FrozenDict

from brax.training import distribution
from brax.training import networks
from brax.training import types
from brax.training.types import PRNGKey
from brax.training.agents.ppo.networks_cnn import VisionMLP


ModuleDef = Any
ActivationFn = Callable[[jp.ndarray], jp.ndarray]
Initializer = Callable[..., Any]


@flax.struct.dataclass
class PPONetworks:
policy_network: networks.FeedForwardNetwork
value_network: networks.FeedForwardNetwork
parametric_action_distribution: distribution.ParametricDistribution


def remove_pixels(obs: Union[jp.ndarray, Mapping]) -> Union[jp.ndarray, Mapping]:
"""Remove pixel observations from the observation dict.
FrozenDicts are used to avoid incorrect gradients."""
if not isinstance(obs, Mapping):
return obs
obs = FrozenDict(obs)

Review comment (Collaborator):

this seems complicated - why not do as above:

return {k: v for k, v in obs.items() if not k.startswith("pixels/")}

pixel_keys = [k for k in obs.keys() if k.startswith("pixels/")]
state_obs = obs
for k in pixel_keys:
state_obs, _ = state_obs.pop(k)
return state_obs
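
A quick illustrative check of remove_pixels as defined in this revision, with assumed shapes:

import jax.numpy as jp
from brax.training.agents.ppo.networks_vision import remove_pixels

obs = {'state': jp.zeros(2), 'pixels/view_0': jp.zeros((4, 4, 3))}
state_only = remove_pixels(obs)
assert set(state_only.keys()) == {'state'}  # pixel entries are dropped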


def make_policy_network_vision(
observation_size: Mapping[str, Tuple[int, ...]],
output_size: int,
preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = [256, 256],
activation: ActivationFn = linen.swish,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
layer_norm: bool = False,
state_obs_key: str = "",
normalise_channels: bool = False,
) -> networks.FeedForwardNetwork:
module = VisionMLP(
layer_sizes=list(hidden_layer_sizes) + [output_size],
activation=activation,
kernel_init=kernel_init,
layer_norm=layer_norm,
normalise_channels=normalise_channels,
state_obs_key=state_obs_key,
)

def apply(processor_params, policy_params, obs):
obs = FrozenDict(obs)
if state_obs_key:
state_obs = preprocess_observations_fn(remove_pixels(obs), processor_params)

Review comment (Collaborator):

i'm not quite sure i follow this. it seems like you are doing this:

  1. preprocess all non-pixel fields (i assume this is for normalization)
  2. take only a SINGLE field (the state obs key) from the result and pass it into the module

if that's the case, why not just preprocess the state obs key?

obs = obs.copy({state_obs_key: state_obs[state_obs_key]})
return module.apply(policy_params, obs)

dummy_obs = {key: jp.zeros((1,) + shape) for key, shape in observation_size.items()}

return networks.FeedForwardNetwork(
init=lambda key: module.init(key, dummy_obs), apply=apply
)


def make_value_network_vision(
observation_size: Mapping[str, Tuple[int, ...]],
preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = [256, 256],
activation: ActivationFn = linen.swish,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
state_obs_key: str = "",
normalise_channels: bool = False,
) -> networks.FeedForwardNetwork:
value_module = VisionMLP(
layer_sizes=list(hidden_layer_sizes) + [1],
activation=activation,
kernel_init=kernel_init,
normalise_channels=normalise_channels,
state_obs_key=state_obs_key,
)

def apply(processor_params, policy_params, obs):
obs = FrozenDict(obs)
if state_obs_key:
# Apply normaliser to state-based params.
state_obs = preprocess_observations_fn(remove_pixels(obs), processor_params)
obs = obs.copy({state_obs_key: state_obs[state_obs_key]})
return jp.squeeze(value_module.apply(policy_params, obs), axis=-1)

dummy_obs = {key: jp.zeros((1,) + shape) for key, shape in observation_size.items()}
return networks.FeedForwardNetwork(
init=lambda key: value_module.init(key, dummy_obs), apply=apply
)


def make_ppo_networks_vision(
# channel_size: int,
observation_size: Mapping[str, Tuple[int, ...]],
action_size: int,
preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = [256, 256],
value_hidden_layer_sizes: Sequence[int] = [256, 256],
activation: ActivationFn = linen.swish,
normalise_channels: bool = False,
policy_obs_key: str = "",
value_obs_key: str = "",
) -> PPONetworks:
"""Make Vision PPO networks with preprocessor."""

parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size
)

policy_network = make_policy_network_vision(
observation_size=observation_size,
output_size=parametric_action_distribution.param_size,
preprocess_observations_fn=preprocess_observations_fn,
activation=activation,
hidden_layer_sizes=policy_hidden_layer_sizes,
state_obs_key=policy_obs_key,
normalise_channels=normalise_channels,
)

value_network = make_value_network_vision(
observation_size=observation_size,
preprocess_observations_fn=preprocess_observations_fn,
activation=activation,
hidden_layer_sizes=value_hidden_layer_sizes,
state_obs_key=value_obs_key,
normalise_channels=normalise_channels,
)

return PPONetworks(
policy_network=policy_network,
value_network=value_network,
parametric_action_distribution=parametric_action_distribution,
)
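
Putting it together, a sketch of constructing the vision PPO networks for a two-camera observation spec; the shapes and key names are illustrative, following the conventions above:

from brax.training.agents.ppo import networks_vision

obs_size = {
    'pixels/view_0': (64, 64, 3),
    'pixels/view_1': (64, 64, 3),
    'state': (8,),
}
ppo_nets = networks_vision.make_ppo_networks_vision(
    observation_size=obs_size,
    action_size=4,
    policy_obs_key='state',
    value_obs_key='state',
)
# ppo_nets.policy_network and ppo_nets.value_network expose the usual
# FeedForwardNetwork init/apply pair.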