-
Notifications
You must be signed in to change notification settings - Fork 272
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PPO on Pixels #560
PPO on Pixels #560
Changes from 19 commits
acc4ed7
6aafe66
9151e0c
f97358c
a99d4d7
131c325
8d6d2c4
583f7d5
b11a7dc
f848837
e899b6d
e597fb6
67db2a5
f68bc0e
30c3087
09377c4
28288c6
9438bd7
27225b2
982e142
24730bf
051c4fa
769752b
71cd073
fac85fb
40c3985
7feed21
be02919
cd7aa15
692d626
19f6337
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
from brax.envs.base import PipelineEnv, State | ||
import jax | ||
from jax import numpy as jp | ||
from flax.core import FrozenDict | ||
|
||
|
||
class Fast(PipelineEnv): | ||
|
@@ -30,7 +31,9 @@ def __init__(self, **kwargs): | |
self._step_count = 0 | ||
self._use_dict_obs = kwargs.get('use_dict_obs', False) | ||
self._asymmetric_obs = kwargs.get('asymmetric_obs', False) | ||
if self._asymmetric_obs and not self._use_dict_obs: | ||
self._pixel_obs = kwargs.get('pixel_obs', False) | ||
self._pixels_only = kwargs.get('pixels_only', False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This made me scratch my head for a moment — can we instead have `self._state_obs`, and then raise an error if both are false? |
||
if (self._asymmetric_obs or self._pixel_obs) and not self._use_dict_obs: | ||
raise ValueError('asymmetric_obs requires use_dict_obs=True') | ||
|
||
def reset(self, rng: jax.Array) -> State: | ||
|
@@ -47,6 +50,12 @@ def reset(self, rng: jax.Array) -> State: | |
obs = {'state': obs} if self._use_dict_obs else obs | ||
if self._asymmetric_obs: | ||
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state. | ||
if self._pixel_obs: | ||
pixels = dict( | ||
{f'pixels/view_{i}': jp.zeros((4, 4, 3)) for i in range(2)} | ||
) | ||
obs = pixels if self._pixels_only else {**obs, **pixels} | ||
obs = FrozenDict(obs) if self._use_dict_obs else obs | ||
reward, done = jp.array(0.0), jp.array(0.0) | ||
return State(pipeline_state, obs, reward, done) | ||
|
||
|
@@ -63,7 +72,13 @@ def step(self, state: State, action: jax.Array) -> State: | |
obs = {'state': obs} if self._use_dict_obs else obs | ||
if self._asymmetric_obs: | ||
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state. | ||
if self._pixel_obs: | ||
pixels = dict( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dumb code is smart: I'd suggest dropping the `for i in range(2)` bit and just spelling it out (as you did in the shape specification below). |
||
{f'pixels/view_{i}': jp.zeros((4, 4, 3)) for i in range(2)} | ||
) | ||
obs = pixels if self._pixels_only else {**obs, **pixels} | ||
reward = pos[0] | ||
obs = FrozenDict(obs) if self._use_dict_obs else obs | ||
return state.replace(pipeline_state=qp, obs=obs, reward=reward) | ||
|
||
@property | ||
|
@@ -82,6 +97,9 @@ def observation_size(self): | |
obs = {'state': 2} | ||
if self._asymmetric_obs: | ||
obs['privileged_state'] = 4 | ||
if self._pixel_obs: | ||
obs['pixels/view_0'] = (4, 4, 3) | ||
obs['pixels/view_1'] = (4, 4, 3) | ||
return obs | ||
|
||
@property | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ def wrap( | |
randomization_fn: Optional[ | ||
Callable[[System], Tuple[System, System]] | ||
] = None, | ||
scan: bool = False | ||
) -> Wrapper: | ||
"""Common wrapper pattern for all training agents. | ||
|
||
|
@@ -46,7 +47,7 @@ def wrap( | |
environment did not already have batch dimensions, it is additional Vmap | ||
wrapped. | ||
""" | ||
env = EpisodeWrapper(env, episode_length, action_repeat) | ||
env = EpisodeWrapper(env, episode_length, action_repeat, scan) | ||
if randomization_fn is None: | ||
env = VmapWrapper(env) | ||
else: | ||
|
@@ -74,10 +75,13 @@ def step(self, state: State, action: jax.Array) -> State: | |
class EpisodeWrapper(Wrapper): | ||
"""Maintains episode step count and sets done at episode end.""" | ||
|
||
def __init__(self, env: Env, episode_length: int, action_repeat: int): | ||
def __init__( | ||
self, env: Env, episode_length: int, | ||
action_repeat: int, scan: bool = False): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm confused by this `scan` param: what is it for? Why does it default to `False`? That seems to break behavior for people who use `EpisodeWrapper`. And why can I specify `action_repeat > 1` together with `scan=False` — shouldn't that be an exception? |
||
super().__init__(env) | ||
self.episode_length = episode_length | ||
self.action_repeat = action_repeat | ||
self._scan = scan | ||
|
||
def reset(self, rng: jax.Array) -> State: | ||
state = self.env.reset(rng) | ||
|
@@ -90,8 +94,12 @@ def f(state, _): | |
nstate = self.env.step(state, action) | ||
return nstate, nstate.reward | ||
|
||
state, rewards = jax.lax.scan(f, state, (), self.action_repeat) | ||
state = state.replace(reward=jp.sum(rewards, axis=0)) | ||
if self._scan: | ||
state, rewards = jax.lax.scan(f, state, (), self.action_repeat) | ||
state = state.replace(reward=jp.sum(rewards, axis=0)) | ||
else: | ||
state = self.env.step(state, action) | ||
|
||
steps = state.info['steps'] + self.action_repeat | ||
one = jp.ones_like(state.done) | ||
zero = jp.zeros_like(state.done) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,10 +124,13 @@ def update(state: RunningStatisticsState, | |
# We require exactly the same structure to avoid issues when flattened | ||
# batch and state have different order of elements. | ||
assert jax.tree_util.tree_structure(batch) == jax.tree_util.tree_structure(state.mean) | ||
batch_shape = jax.tree_util.tree_leaves(batch)[0].shape | ||
batch_leaves = jax.tree_util.tree_leaves(batch) | ||
batch_shape = batch_leaves[0].shape if batch_leaves else () | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this to handle the scenario where there's nothing to normalize, e.g. if your env only returns pixels and nothing else? It may be clearer to handle this explicitly instead of allowing the code to execute normally, e.g.: batch_leaves = ...
if not batch_leaves:  # nothing to normalize
    return Something  # (I'll let you figure it out) |
||
# We assume the batch dimensions always go first. | ||
batch_dims = batch_shape[:len(batch_shape) - | ||
jax.tree_util.tree_leaves(state.mean)[0].ndim] | ||
batch_dims = batch_shape[ | ||
: len(batch_shape) | ||
- (jax.tree_util.tree_leaves(state.mean)[0].ndim if batch_leaves else 0) | ||
] | ||
batch_axis = range(len(batch_dims)) | ||
if weights is None: | ||
step_increment = jnp.prod(jnp.array(batch_dims)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Network implementations | ||
""" | ||
|
||
from functools import partial | ||
from typing import Any, Callable, Sequence, Tuple | ||
|
||
from flax import linen | ||
import jax | ||
import jax.numpy as jp | ||
|
||
from brax.training import networks | ||
|
||
ModuleDef = Any | ||
ActivationFn = Callable[[jp.ndarray], jp.ndarray] | ||
Initializer = Callable[..., Any] | ||
|
||
|
||
class CNN(linen.Module):
    """Stack of 2D convolutions, each followed by the activation.

    Inputs follow the NHWC convention: every image is 3D (height, width,
    channels) and may be preceded by batch dimensions.

    Attributes:
        num_filters: number of output filters, one entry per conv layer.
        kernel_sizes: kernel size, one entry per conv layer.
        strides: stride, one entry per conv layer.
        activation: activation applied after every conv layer.
        use_bias: whether the conv layers add a bias term.
    """

    num_filters: Sequence[int]
    kernel_sizes: Sequence[Tuple]
    strides: Sequence[Tuple]
    activation: ActivationFn = linen.relu
    use_bias: bool = True

    @linen.compact
    def __call__(self, data: jp.ndarray):
        """Apply the conv layers to `data` in order and return the features.

        Raises:
            ValueError: if the per-layer settings differ in length.
        """
        # zip() would silently drop layers if one sequence were shorter, so
        # reject inconsistent configurations up front.
        if not (
            len(self.num_filters) == len(self.kernel_sizes) == len(self.strides)
        ):
            raise ValueError(
                "num_filters, kernel_sizes and strides must have the same "
                f"length, got {len(self.num_filters)}, "
                f"{len(self.kernel_sizes)} and {len(self.strides)}"
            )

        hidden = data
        for num_filter, kernel_size, stride in zip(
            self.num_filters, self.kernel_sizes, self.strides
        ):
            hidden = linen.Conv(
                num_filter,
                kernel_size=kernel_size,
                strides=stride,
                use_bias=self.use_bias,
            )(hidden)
            hidden = self.activation(hidden)
        return hidden
|
||
|
||
class VisionMLP(linen.Module):
    """Apply a CNN backbone to every pixel observation, then an MLP.

    The convolutional layer configuration (filters, kernels, strides)
    originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.

    Attributes:
        layer_sizes: sizes of the MLP layers applied to the joint features.
        activation: activation used by the MLP.
        kernel_init: kernel initializer used by the MLP.
        activate_final: forwarded to the MLP.
        layer_norm: forwarded to the MLP.
        normalise_channels: if True, layer-normalise each image channel
            (without learned scale or bias) before the CNN.
        state_obs_key: if non-empty, `data[state_obs_key]` is concatenated
            to the CNN features before the MLP.
    """

    layer_sizes: Sequence[int]
    activation: ActivationFn = linen.relu
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    activate_final: bool = False
    layer_norm: bool = False
    normalise_channels: bool = False
    state_obs_key: str = ""

    @linen.compact
    def __call__(self, data: dict):
        """Encode the "pixels/" entries of `data` and run the MLP on them."""
        # Only entries whose key starts with "pixels/" are treated as images.
        pixels_hidden = {
            key: v for key, v in data.items() if key.startswith("pixels/")
        }

        if self.normalise_channels:
            # Calculates shared statistics over an entire 2D image.
            image_layernorm = partial(
                linen.LayerNorm,
                use_bias=False,
                use_scale=False,
                reduction_axes=(-1, -2),
            )

            def ln_per_chan(v: jax.Array):
                normalised = [
                    image_layernorm()(v[..., chan])
                    for chan in range(v.shape[-1])
                ]
                return jp.stack(normalised, axis=-1)

            pixels_hidden = {
                key: ln_per_chan(v) for key, v in pixels_hidden.items()
            }

        nature_cnn = partial(
            CNN,
            num_filters=[32, 64, 64],
            kernel_sizes=[(8, 8), (4, 4), (3, 3)],
            strides=[(4, 4), (2, 2), (1, 1)],
            activation=linen.relu,
            use_bias=False,
        )
        # One separately-parameterised CNN per pixel observation.
        cnn_outs = [nature_cnn()(image) for image in pixels_hidden.values()]
        # Average over axes -2 and -3 (the spatial axes under NHWC).
        cnn_outs = [jp.mean(cnn_out, axis=(-2, -3)) for cnn_out in cnn_outs]
        if self.state_obs_key:
            cnn_outs.append(
                data[self.state_obs_key]
            )  # TODO: Try with dedicated state network

        hidden = jp.concatenate(cnn_outs, axis=-1)
        return networks.MLP(
            layer_sizes=self.layer_sizes,
            activation=self.activation,
            kernel_init=self.kernel_init,
            activate_final=self.activate_final,
            layer_norm=self.layer_norm,
        )(hidden)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
"""PPO vision networks.""" | ||
|
||
from typing import Any, Callable, Mapping, Sequence, Tuple, Union | ||
|
||
import jax | ||
import jax.numpy as jp | ||
import flax | ||
from flax import linen | ||
from flax.core import FrozenDict | ||
|
||
from brax.training import distribution | ||
from brax.training import networks | ||
from brax.training import types | ||
from brax.training.types import PRNGKey | ||
from brax.training.agents.ppo.networks_cnn import VisionMLP | ||
|
||
|
||
ModuleDef = Any | ||
ActivationFn = Callable[[jp.ndarray], jp.ndarray] | ||
Initializer = Callable[..., Any] | ||
|
||
|
||
@flax.struct.dataclass
class PPONetworks:
    """Container for the networks and action distribution used by vision PPO."""

    # Network whose output parameterises the action distribution.
    policy_network: networks.FeedForwardNetwork
    # Network that produces the value estimate.
    value_network: networks.FeedForwardNetwork
    # Distribution that turns policy outputs into actions.
    parametric_action_distribution: distribution.ParametricDistribution
|
||
|
||
def remove_pixels(obs: Union[jp.ndarray, Mapping]) -> Union[jp.ndarray, Mapping]:
    """Remove pixel observations from the observation dict.

    Args:
        obs: either a plain array observation or a mapping of observations.

    Returns:
        `obs` itself when it is not a mapping. Otherwise a new FrozenDict
        holding every entry whose key does not start with "pixels/".
        FrozenDicts are used to avoid incorrect gradients.
    """
    if not isinstance(obs, Mapping):
        return obs
    # Build the filtered dict in one pass rather than popping keys one by
    # one, which would copy the FrozenDict once per pixel key.
    return FrozenDict(
        {k: v for k, v in obs.items() if not k.startswith("pixels/")}
    )
|
||
|
||
def make_policy_network_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    output_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.swish,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    layer_norm: bool = False,
    state_obs_key: str = "",
    normalise_channels: bool = False,
) -> networks.FeedForwardNetwork:
    """Create a policy network that consumes pixel observations.

    Args:
        observation_size: maps each observation key to its unbatched shape.
        output_size: size of the final MLP layer.
        preprocess_observations_fn: applied to the non-pixel observations
            when `state_obs_key` is set.
        hidden_layer_sizes: sizes of the MLP hidden layers.
        activation: MLP activation.
        kernel_init: MLP kernel initializer.
        layer_norm: forwarded to the VisionMLP.
        state_obs_key: optional key of a non-pixel observation that is fed
            to the MLP alongside the CNN features.
        normalise_channels: forwarded to the VisionMLP.

    Returns:
        A FeedForwardNetwork whose `apply` takes
        `(processor_params, policy_params, obs)`.
    """
    module = VisionMLP(
        layer_sizes=list(hidden_layer_sizes) + [output_size],
        activation=activation,
        kernel_init=kernel_init,
        layer_norm=layer_norm,
        normalise_channels=normalise_channels,
        state_obs_key=state_obs_key,
    )

    def apply(processor_params, policy_params, obs):
        obs = FrozenDict(obs)
        if state_obs_key:
            # Pixels are stripped before preprocessing; only the processed
            # `state_obs_key` entry is written back into the observation.
            # NOTE(review): presumably processor_params mirrors the
            # structure of the non-pixel observations, which is why the
            # whole non-pixel dict is preprocessed - confirm.
            state_obs = preprocess_observations_fn(
                remove_pixels(obs), processor_params
            )
            obs = obs.copy({state_obs_key: state_obs[state_obs_key]})
        return module.apply(policy_params, obs)

    # Batch-of-one zeros with the declared shapes, used only for init.
    dummy_obs = {
        key: jp.zeros((1, *shape)) for key, shape in observation_size.items()
    }

    return networks.FeedForwardNetwork(
        init=lambda key: module.init(key, dummy_obs), apply=apply
    )
|
||
|
||
def make_value_network_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.swish,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    state_obs_key: str = "",
    normalise_channels: bool = False,
) -> networks.FeedForwardNetwork:
    """Create a value network that consumes pixel observations.

    Args:
        observation_size: maps each observation key to its unbatched shape.
        preprocess_observations_fn: applied to the non-pixel observations
            when `state_obs_key` is set.
        hidden_layer_sizes: sizes of the MLP hidden layers; a final layer of
            size 1 is appended.
        activation: MLP activation.
        kernel_init: MLP kernel initializer.
        state_obs_key: optional key of a non-pixel observation that is fed
            to the MLP alongside the CNN features.
        normalise_channels: forwarded to the VisionMLP.

    Returns:
        A FeedForwardNetwork whose `apply` takes
        `(processor_params, policy_params, obs)` and returns the value with
        the trailing unit axis squeezed away.
    """
    value_module = VisionMLP(
        layer_sizes=list(hidden_layer_sizes) + [1],
        activation=activation,
        kernel_init=kernel_init,
        normalise_channels=normalise_channels,
        state_obs_key=state_obs_key,
    )

    def apply(processor_params, policy_params, obs):
        obs = FrozenDict(obs)
        if state_obs_key:
            # Apply normaliser to state-based params.
            state_obs = preprocess_observations_fn(
                remove_pixels(obs), processor_params
            )
            obs = obs.copy({state_obs_key: state_obs[state_obs_key]})
        return jp.squeeze(value_module.apply(policy_params, obs), axis=-1)

    # Batch-of-one zeros with the declared shapes, used only for init.
    dummy_obs = {
        key: jp.zeros((1, *shape)) for key, shape in observation_size.items()
    }
    return networks.FeedForwardNetwork(
        init=lambda key: value_module.init(key, dummy_obs), apply=apply
    )
|
||
|
||
def make_ppo_networks_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = (256, 256),
    value_hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.swish,
    normalise_channels: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> PPONetworks:
    """Make Vision PPO networks with preprocessor.

    Args:
        observation_size: maps each observation key to its unbatched shape.
        action_size: event size of the tanh-normal action distribution.
        preprocess_observations_fn: forwarded to both networks.
        policy_hidden_layer_sizes: hidden layer sizes of the policy MLP.
        value_hidden_layer_sizes: hidden layer sizes of the value MLP.
        activation: activation used by both MLPs.
        normalise_channels: forwarded to both networks.
        policy_obs_key: optional non-pixel observation key for the policy.
        value_obs_key: optional non-pixel observation key for the value
            network.

    Returns:
        A PPONetworks bundle with the policy network, the value network and
        the parametric action distribution.
    """
    parametric_action_distribution = distribution.NormalTanhDistribution(
        event_size=action_size
    )

    policy_network = make_policy_network_vision(
        observation_size=observation_size,
        output_size=parametric_action_distribution.param_size,
        preprocess_observations_fn=preprocess_observations_fn,
        activation=activation,
        hidden_layer_sizes=policy_hidden_layer_sizes,
        state_obs_key=policy_obs_key,
        normalise_channels=normalise_channels,
    )

    value_network = make_value_network_vision(
        observation_size=observation_size,
        preprocess_observations_fn=preprocess_observations_fn,
        activation=activation,
        hidden_layer_sizes=value_hidden_layer_sizes,
        state_obs_key=value_obs_key,
        normalise_channels=normalise_channels,
    )

    return PPONetworks(
        policy_network=policy_network,
        value_network=value_network,
        parametric_action_distribution=parametric_action_distribution,
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, why are we using `kwargs.get` instead of just adding the param to the function signature?
Explicit parameters are better for type checking.
Feel free to ignore this if it breaks something somewhere else.