Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPO on Pixels #560

Merged
merged 31 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
acc4ed7
Initial commit of vision PPO networks
StafaH Nov 22, 2024
6aafe66
implement vision wrappers
Andrew-Luo1 Nov 24, 2024
9151e0c
change ppo loss and the autoresetwrapper to support dictionary-valued…
Andrew-Luo1 Nov 24, 2024
f97358c
add random image shifts
Andrew-Luo1 Nov 24, 2024
a99d4d7
support normalising observations, clean up train_pixels.py
Andrew-Luo1 Nov 25, 2024
131c325
vision networks
Andrew-Luo1 Nov 25, 2024
8d6d2c4
fix bug in state normalisation
Andrew-Luo1 Nov 26, 2024
583f7d5
add channel-wise layer norm in CNN
Andrew-Luo1 Nov 26, 2024
b11a7dc
remove old file
Andrew-Luo1 Nov 26, 2024
f848837
clean up imports
Andrew-Luo1 Nov 26, 2024
e899b6d
enforce FrozenDict to avoid incorrect gradients
Andrew-Luo1 Nov 27, 2024
e597fb6
refactor the vision wrappers as flags in envs.training wrappers
Andrew-Luo1 Nov 27, 2024
67db2a5
support asymmetric actor critic on pixels, clean up normalisation logic
Andrew-Luo1 Dec 2, 2024
f68bc0e
rename networks files
Andrew-Luo1 Dec 2, 2024
30c3087
write basic pixels ppo test, make remove_pixels() check for non-dict obs
Andrew-Luo1 Dec 2, 2024
09377c4
update test for ppo on pixels to test pixel-only observations and cas…
Andrew-Luo1 Dec 2, 2024
28288c6
fix bug for aac on pixels
Andrew-Luo1 Dec 2, 2024
9438bd7
remove old file
Andrew-Luo1 Dec 2, 2024
27225b2
linting
Andrew-Luo1 Dec 2, 2024
982e142
clean up logic for toy testing env
Andrew-Luo1 Dec 3, 2024
24730bf
small code placement and logic clean-up
Andrew-Luo1 Dec 3, 2024
051c4fa
for vision networks, only normalize as needed
Andrew-Luo1 Dec 3, 2024
769752b
move vision networks around
Andrew-Luo1 Dec 3, 2024
71cd073
remove scan parameter for wrapping but switch wrapping order
Andrew-Luo1 Dec 3, 2024
fac85fb
linting
Andrew-Luo1 Dec 3, 2024
40c3985
add acknowledgement
Andrew-Luo1 Dec 3, 2024
7feed21
replace boolean args to testing env with obs_mode enum
Andrew-Luo1 Dec 3, 2024
be02919
write docstring for toy testing env and clean up
Andrew-Luo1 Dec 3, 2024
cd7aa15
make pixels functions private
Andrew-Luo1 Dec 4, 2024
692d626
Merge branch 'main' into vision_ppo_rebased
Andrew-Luo1 Dec 4, 2024
19f6337
update sac test
Andrew-Luo1 Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions brax/envs/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@


class ObservationMode(Enum):
"""
Describes observation formats.

Attributes:
NDARRAY: Flat NumPy array of state info.
DICT_STATE: Dictionary of state info.
DICT_PIXELS: Dictionary of pixel observations.
DICT_PIXELS_STATE: Dictionary of both state and pixel info.
"""
NDARRAY = "ndarray"
DICT_STATE = "dict_state"
DICT_PIXELS = "dict_pixels"
Expand Down Expand Up @@ -122,9 +131,8 @@ def observation_size(self):
ret = super().observation_size
if self._obs_mode == ObservationMode.NDARRAY:
return ret
else:
# Turn 1-D tuples to ints.
return {key: value[0] if len(value) == 1 else value for key, value in ret.items()}
# Turn 1-D tuples to ints.
return {key: value[0] if len(value) == 1 else value for key, value in ret.items()}

@property
def action_size(self):
Expand Down
84 changes: 45 additions & 39 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,46 @@ def f(leaf):


def remove_pixels(obs: Union[jnp.ndarray, Mapping]) -> Union[jnp.ndarray, Mapping]:
  """Strips every pixel entry (keys prefixed "pixels/") from an observation.

  Non-mapping observations (plain arrays) are returned untouched. The result
  is wrapped in a FrozenDict to avoid incorrect gradients.
  """
  if isinstance(obs, Mapping):
    state_only = {}
    for name, value in obs.items():
      if name.startswith("pixels/"):
        continue
      state_only[name] = value
    return FrozenDict(state_only)
  return obs


def random_translate_pixels(
    obs: Mapping[str, jax.Array], key: PRNGKey, padding: int = 4
) -> Mapping[str, jax.Array]:
  """Applies random translations to B x T x H x W x C pixel observations.

  The same shift is applied across the unroll_length (T) dimension; entries
  whose key does not start with "pixels/" are passed through unchanged.

  Args:
    obs: dictionary of batched observations; pixel entries are expected to be
      B x T x H x W x C arrays.
    key: a PRNG key.
    padding: maximum translation, in pixels, along each spatial axis
      (defaults to 4, matching the previous hard-coded behavior).

  Returns:
    A FrozenDict of observations with the pixel entries randomly shifted.
  """
  obs = FrozenDict(obs)  # FrozenDicts are used to avoid incorrect gradients.

  @jax.vmap
  def rt_all_views(
      ub_obs: Mapping[str, jax.Array], key: PRNGKey
  ) -> Mapping[str, jax.Array]:
    # Expects a dictionary of unbatched (single batch element) observations.
    def rt_view(img: jax.Array, padding: int, key: PRNGKey) -> jax.Array:
      # Randomly translates a T x H x W x C set of pixel inputs. Adapted from
      # https://github.com/ikostrikov/jaxrl/blob/main/jaxrl/agents/drq/augmentations.py
      crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
      zero = jnp.zeros((1,), dtype=jnp.int32)
      # No shift along the T and C axes; only H and W are translated.
      crop_from = jnp.concatenate([zero, crop_from, zero])
      padded_img = jnp.pad(
          img,
          ((0, 0), (padding, padding), (padding, padding), (0, 0)),
          mode="edge",
      )
      return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)

    out = {}
    for k_view, v_view in ub_obs.items():
      if k_view.startswith("pixels/"):
        key, key_shift = jax.random.split(key)
        out[k_view] = rt_view(v_view, padding, key_shift)
    return ub_obs.copy(out)  # Update only the shifted fields.

  if not obs:
    # Nothing to translate; the old code crashed here with a TypeError
    # (`None[1]`) because of the `next(iter(...), None)` default.
    return obs
  bdim = next(iter(obs.values())).shape[0]
  keys = jax.random.split(key, bdim)
  return rt_all_views(obs, keys)


def train(
environment: Union[envs_v1.Env, envs.Env],
num_timesteps: int,
Expand Down Expand Up @@ -176,9 +209,12 @@ def train(
Tuple of (make_policy function, network params, metrics)
"""
if madrona_backend:
assert not eval_env, "Madrona-MJX doesn't support multiple env instances"
assert num_eval_envs == num_envs, "Madrona-MJX requires a fixed batch size"
assert action_repeat == 1, "Implement action_repeat using PipelineEnv's _n_frames to avoid unnecessary rendering!"
if eval_env:
raise ValueError("Madrona-MJX doesn't support multiple env instances")
if num_eval_envs != num_envs:
raise ValueError("Madrona-MJX requires a fixed batch size")
if action_repeat != 1:
raise ValueError("Implement action_repeat using PipelineEnv's _n_frames to avoid unnecessary rendering!")

assert batch_size * num_minibatches % num_envs == 0
xt = time.time()
Expand Down Expand Up @@ -283,36 +319,6 @@ def train(
gradient_update_fn = gradients.gradient_update_fn(
loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)

def random_translate_pixels(obs: Mapping[str, jax.Array], key):
obs = FrozenDict(obs)

@jax.vmap
def rt_all_views(ub_obs: Mapping[str, jax.Array], key) -> Mapping[str, jax.Array]:
# Expects dictionary of unbatched observations.
def rt_view(key, img: jax.Array, padding) -> jax.Array: # TxHxWxC
# Randomly translates a set of pixel inputs.
# Adapted from https://github.com/ikostrikov/jaxrl/blob/main/jaxrl/agents/drq/augmentations.py
crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
zero = jnp.zeros((1,), dtype=jnp.int32)
crop_from = jnp.concatenate([zero, crop_from, zero])
padded_img = jnp.pad(
img, ((0, 0), (padding, padding), (padding, padding), (0, 0)), mode="edge"
)
return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)

out = {}
for k_view, v_view in ub_obs.items():
if k_view.startswith("pixels/"):
key, key_shift = jax.random.split(key)
out[k_view] = rt_view(key_shift, v_view, 4)
ub_obs = ub_obs.copy(out) # Update the shifted fields
return ub_obs

bdim = next(iter(obs.items()), None)[1].shape[0]
keys = jax.random.split(key, bdim)
obs = rt_all_views(obs, keys)
return obs

def minibatch_step(
carry, data: types.Transition,
normalizer_params: running_statistics.RunningStatisticsState):
Expand All @@ -332,11 +338,6 @@ def sgd_step(carry, unused_t, data: types.Transition,
optimizer_state, params, key = carry
key, key_perm, key_grad = jax.random.split(key, 3)

def convert_data(x: jnp.ndarray):
x = jax.random.permutation(key_perm, x)
x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
return x

if augment_pixels:
key, key_rt = jax.random.split(key)
r_translate = functools.partial(random_translate_pixels, key=key_rt)
Expand All @@ -349,6 +350,11 @@ def convert_data(x: jnp.ndarray):
extras=data.extras
)

def convert_data(x: jnp.ndarray):
x = jax.random.permutation(key_perm, x)
x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
return x

shuffled_data = jax.tree_util.tree_map(convert_data, data)
(optimizer_state, params, _), metrics = jax.lax.scan(
functools.partial(minibatch_step, normalizer_params=normalizer_params),
Expand Down