Commit: linting

Andrew-Luo1 committed Dec 2, 2024
1 parent 9438bd7 commit e028186
Showing 5 changed files with 132 additions and 117 deletions.
7 changes: 4 additions & 3 deletions brax/training/acme/running_statistics.py
@@ -127,9 +127,10 @@ def update(state: RunningStatisticsState,
   batch_leaves = jax.tree_util.tree_leaves(batch)
   batch_shape = batch_leaves[0].shape if batch_leaves else ()
   # We assume the batch dimensions always go first.
-  batch_dims = batch_shape[:len(batch_shape) -
-                           (jax.tree_util.tree_leaves(state.mean)[0].ndim if
-                            batch_leaves else 0)]
+  batch_dims = batch_shape[
+      : len(batch_shape)
+      - (jax.tree_util.tree_leaves(state.mean)[0].ndim if batch_leaves else 0)
+  ]
   batch_axis = range(len(batch_dims))
   if weights is None:
     step_increment = jnp.prod(jnp.array(batch_dims))
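Illustration (not part of this commit): what the batch_dims slice above computes. If the running-statistics mean tracks a flat observation vector (ndim 1) and the incoming batch is (num_envs, unroll_length, obs_dim), everything except the trailing feature dimension counts as a batch dimension.

import jax.numpy as jnp

batch_shape = (8, 32, 17)   # (num_envs, unroll_length, obs_dim), made-up sizes
mean_ndim = 1               # state.mean is a flat vector per observation
batch_dims = batch_shape[: len(batch_shape) - mean_ndim]
assert batch_dims == (8, 32)
step_increment = jnp.prod(jnp.array(batch_dims))  # 256 samples folded into the stats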
63 changes: 33 additions & 30 deletions brax/training/agents/ppo/networks_cnn.py
@@ -22,6 +22,7 @@ class CNN(linen.Module):
     num_filters: the number of filters per layer
     kernel_sizes: also per layer
   """
+
   num_filters: Sequence[int]
   kernel_sizes: Sequence[Tuple]
   strides: Sequence[Tuple]
@@ -32,15 +33,12 @@ class CNN(linen.Module):
   def __call__(self, data: jp.ndarray):
     hidden = data
     for i, (num_filter, kernel_size, stride) in enumerate(
-        zip(self.num_filters, self.kernel_sizes, self.strides)):
+        zip(self.num_filters, self.kernel_sizes, self.strides)
+    ):
       hidden = linen.Conv(
-          num_filter,
-          kernel_size=kernel_size,
-          strides=stride,
-          use_bias=self.use_bias)(
-              hidden)
+          num_filter, kernel_size=kernel_size, strides=stride, use_bias=self.use_bias
+      )(hidden)
 
       hidden = self.activation(hidden)
     return hidden
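Illustration (not part of this commit): instantiating this CNN module with the NatureCNN-style configuration that VisionMLP uses below; the input shape and the 'SAME'-padding output size are assumptions.

import jax
import jax.numpy as jp
from flax import linen
from brax.training.agents.ppo.networks_cnn import CNN

cnn = CNN(
    num_filters=[32, 64, 64],
    kernel_sizes=[(8, 8), (4, 4), (3, 3)],
    strides=[(4, 4), (2, 2), (1, 1)],
    activation=linen.relu,
    use_bias=False,
)
dummy = jp.zeros((1, 64, 64, 3))  # (batch, H, W, C), illustrative
params = cnn.init(jax.random.PRNGKey(0), dummy)
features = cnn.apply(params, dummy)  # e.g. (1, 8, 8, 64) with Flax's default 'SAME' padding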

@@ -53,41 +51,46 @@ class VisionMLP(linen.Module):
   activate_final: bool = False
   layer_norm: bool = False
   normalise_channels: bool = False
-  state_obs_key: str = ''
+  state_obs_key: str = ""
 
   @linen.compact
   def __call__(self, data: dict):
     if self.normalise_channels:
       # Calculates shared statistics over an entire 2D image.
-      image_layernorm = partial(linen.LayerNorm, use_bias=False, use_scale=False,
-                                reduction_axes=(-1, -2))
+      image_layernorm = partial(
+          linen.LayerNorm, use_bias=False, use_scale=False, reduction_axes=(-1, -2)
+      )
 
       def ln_per_chan(v: jax.Array):
         normalised = [image_layernorm()(v[..., chan]) for chan in range(v.shape[-1])]
         return jp.stack(normalised, axis=-1)
 
       pixels_hidden = {
-          key: ln_per_chan(v) for key, v in data.items()
-          if key.startswith('pixels/')}
+          key: ln_per_chan(v) for key, v in data.items() if key.startswith("pixels/")
+      }
     else:
-      pixels_hidden = {
-          k: v for k, v in data.items() if k.startswith('pixels/')}
-
-    natureCNN = partial(CNN,
-                        num_filters=[32, 64, 64],
-                        kernel_sizes=[(8, 8), (4, 4), (3, 3)],
-                        strides=[(4, 4), (2, 2), (1, 1)],
-                        activation=linen.relu,
-                        use_bias=False)
+      pixels_hidden = {k: v for k, v in data.items() if k.startswith("pixels/")}
+
+    natureCNN = partial(
+        CNN,
+        num_filters=[32, 64, 64],
+        kernel_sizes=[(8, 8), (4, 4), (3, 3)],
+        strides=[(4, 4), (2, 2), (1, 1)],
+        activation=linen.relu,
+        use_bias=False,
+    )
     cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden.keys()]
     cnn_outs = [jp.mean(cnn_out, axis=(-2, -3)) for cnn_out in cnn_outs]
     if self.state_obs_key:
-      cnn_outs.append(data[self.state_obs_key]) # TODO: Try with dedicated state network
+      cnn_outs.append(
+          data[self.state_obs_key]
+      )  # TODO: Try with dedicated state network
 
     hidden = jp.concatenate(cnn_outs, axis=-1)
-    return networks.MLP(layer_sizes=self.layer_sizes,
-                        activation=self.activation,
-                        kernel_init=self.kernel_init,
-                        activate_final=self.activate_final,
-                        layer_norm=self.layer_norm)(hidden)
+    return networks.MLP(
+        layer_sizes=self.layer_sizes,
+        activation=self.activation,
+        kernel_init=self.kernel_init,
+        activate_final=self.activate_final,
+        layer_norm=self.layer_norm,
+    )(hidden)
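Illustration (not part of this commit): a minimal sketch of applying VisionMLP to a mixed observation dict; the key names, shapes, and reliance on the module's default kernel initialiser are assumptions. Keys starting with "pixels/" are (optionally channel-normalised and) passed through the NatureCNN stack and mean-pooled over space; the optional state_obs_key entry is concatenated to the pooled features before the MLP head.

import jax
import jax.numpy as jp
from flax import linen
from brax.training.agents.ppo.networks_cnn import VisionMLP

vision_mlp = VisionMLP(
    layer_sizes=[256, 256, 8],
    activation=linen.swish,
    normalise_channels=True,
    state_obs_key="state",
)
obs = {
    "pixels/view_0": jp.zeros((1, 64, 64, 3)),  # illustrative camera view
    "state": jp.zeros((1, 11)),                 # illustrative proprioceptive state
}
params = vision_mlp.init(jax.random.PRNGKey(0), obs)
out = vision_mlp.apply(params, obs)  # (1, 8): pooled CNN features + state -> MLP head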
80 changes: 41 additions & 39 deletions brax/training/agents/ppo/networks_vision.py
@@ -33,7 +33,7 @@ def remove_pixels(obs: Union[jp.ndarray, Mapping]) -> Union[jp.ndarray, Mapping]
   if not isinstance(obs, Mapping):
     return obs
   obs = FrozenDict(obs)
-  pixel_keys = [k for k in obs.keys() if k.startswith('pixels/')]
+  pixel_keys = [k for k in obs.keys() if k.startswith("pixels/")]
   state_obs = obs
   for k in pixel_keys:
     state_obs, _ = state_obs.pop(k)
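Illustration (not part of this commit): remove_pixels keeps only non-image entries, so the running-statistics normaliser never sees pixel observations; the key names below are made up.

import jax.numpy as jp
from brax.training.agents.ppo.networks_vision import remove_pixels

obs = {"pixels/view_0": jp.zeros((64, 64, 3)), "state": jp.zeros((11,))}
state_only = remove_pixels(obs)  # FrozenDict with only the "state" entry
assert set(state_only.keys()) == {"state"}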
@@ -48,31 +48,30 @@ def make_policy_network_vision(
     activation: ActivationFn = linen.swish,
     kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
     layer_norm: bool = False,
-    state_obs_key: str = '',
-    normalise_channels: bool = False) -> networks.FeedForwardNetwork:
-
+    state_obs_key: str = "",
+    normalise_channels: bool = False,
+) -> networks.FeedForwardNetwork:
   module = VisionMLP(
-      layer_sizes=list(hidden_layer_sizes) + [output_size],
-      activation=activation,
-      kernel_init=kernel_init,
-      layer_norm=layer_norm,
-      normalise_channels=normalise_channels,
-      state_obs_key=state_obs_key)
+      layer_sizes=list(hidden_layer_sizes) + [output_size],
+      activation=activation,
+      kernel_init=kernel_init,
+      layer_norm=layer_norm,
+      normalise_channels=normalise_channels,
+      state_obs_key=state_obs_key,
+  )
 
   def apply(processor_params, policy_params, obs):
     obs = FrozenDict(obs)
     if state_obs_key:
-      state_obs = preprocess_observations_fn(
-          remove_pixels(obs), processor_params
-      )
+      state_obs = preprocess_observations_fn(remove_pixels(obs), processor_params)
       obs = obs.copy({state_obs_key: state_obs[state_obs_key]})
     return module.apply(policy_params, obs)
 
-  dummy_obs = {key: jp.zeros((1,) + shape )
-               for key, shape in observation_size.items()}
-
+  dummy_obs = {key: jp.zeros((1,) + shape) for key, shape in observation_size.items()}
 
   return networks.FeedForwardNetwork(
-      init=lambda key: module.init(key, dummy_obs), apply=apply)
+      init=lambda key: module.init(key, dummy_obs), apply=apply
+  )
 
 
 def make_value_network_vision(
@@ -81,48 +80,48 @@ def make_value_network_vision(
     hidden_layer_sizes: Sequence[int] = [256, 256],
     activation: ActivationFn = linen.swish,
     kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
-    state_obs_key: str = '',
-    normalise_channels: bool = False) -> networks.FeedForwardNetwork:
-
+    state_obs_key: str = "",
+    normalise_channels: bool = False,
+) -> networks.FeedForwardNetwork:
   value_module = VisionMLP(
-      layer_sizes=list(hidden_layer_sizes) + [1],
-      activation=activation,
-      kernel_init=kernel_init,
-      normalise_channels=normalise_channels,
-      state_obs_key=state_obs_key)
+      layer_sizes=list(hidden_layer_sizes) + [1],
+      activation=activation,
+      kernel_init=kernel_init,
+      normalise_channels=normalise_channels,
+      state_obs_key=state_obs_key,
+  )
 
   def apply(processor_params, policy_params, obs):
     obs = FrozenDict(obs)
     if state_obs_key:
       # Apply normaliser to state-based params.
-      state_obs = preprocess_observations_fn(
-          remove_pixels(obs), processor_params
-      )
+      state_obs = preprocess_observations_fn(remove_pixels(obs), processor_params)
       obs = obs.copy({state_obs_key: state_obs[state_obs_key]})
     return jp.squeeze(value_module.apply(policy_params, obs), axis=-1)
 
-  dummy_obs = {key: jp.zeros((1,) + shape )
-               for key, shape in observation_size.items()}
+  dummy_obs = {key: jp.zeros((1,) + shape) for key, shape in observation_size.items()}
   return networks.FeedForwardNetwork(
-      init=lambda key: value_module.init(key, dummy_obs), apply=apply)
+      init=lambda key: value_module.init(key, dummy_obs), apply=apply
+  )
 
 
 def make_ppo_networks_vision(
     # channel_size: int,
     observation_size: Mapping[str, Tuple[int, ...]],
     action_size: int,
-    preprocess_observations_fn: types.PreprocessObservationFn = types
-    .identity_observation_preprocessor,
+    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = [256, 256],
    value_hidden_layer_sizes: Sequence[int] = [256, 256],
    activation: ActivationFn = linen.swish,
    normalise_channels: bool = False,
-    policy_obs_key: str = '',
-    value_obs_key: str = '') -> PPONetworks:
+    policy_obs_key: str = "",
+    value_obs_key: str = "",
+) -> PPONetworks:
   """Make Vision PPO networks with preprocessor."""
 
   parametric_action_distribution = distribution.NormalTanhDistribution(
-      event_size=action_size)
+      event_size=action_size
+  )
 
   policy_network = make_policy_network_vision(
       observation_size=observation_size,
@@ -131,17 +130,20 @@ def make_ppo_networks_vision(
       activation=activation,
       hidden_layer_sizes=policy_hidden_layer_sizes,
       state_obs_key=policy_obs_key,
-      normalise_channels=normalise_channels)
+      normalise_channels=normalise_channels,
+  )
 
   value_network = make_value_network_vision(
       observation_size=observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       activation=activation,
       hidden_layer_sizes=value_hidden_layer_sizes,
       state_obs_key=value_obs_key,
-      normalise_channels=normalise_channels)
+      normalise_channels=normalise_channels,
+  )
 
   return PPONetworks(
       policy_network=policy_network,
       value_network=value_network,
-      parametric_action_distribution=parametric_action_distribution)
+      parametric_action_distribution=parametric_action_distribution,
+  )
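Illustration (not part of this commit): one way these factories might be wired together; the observation keys and sizes are made up, and using running_statistics.normalize as the preprocessor is an assumption rather than something this diff shows.

import jax
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks_vision as ppo_networks_vision

observation_size = {"pixels/view_0": (64, 64, 3), "state": (11,)}
ppo_nets = ppo_networks_vision.make_ppo_networks_vision(
    observation_size=observation_size,
    action_size=4,
    preprocess_observations_fn=running_statistics.normalize,
    normalise_channels=True,
    policy_obs_key="state",
    value_obs_key="state",
)
policy_params = ppo_nets.policy_network.init(jax.random.PRNGKey(0))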
23 changes: 13 additions & 10 deletions brax/training/agents/ppo/train.py
@@ -279,24 +279,26 @@ def train(

   def random_translate_pixels(obs: Mapping[str, jax.Array], key):
     obs = FrozenDict(obs)
+
     @jax.vmap
     def rt_all_views(ub_obs: Mapping[str, jax.Array], key) -> Mapping[str, jax.Array]:
-      # Expects dictionary of unbatched observations.
-      def rt_view(key, img: jax.Array, padding) -> jax.Array: # TxHxWxC
+      # Expects dictionary of unbatched observations.
+      def rt_view(key, img: jax.Array, padding) -> jax.Array:  # TxHxWxC
         # Randomly translates a set of pixel inputs.
-        crop_from = jax.random.randint(key, (2, ), 0, 2 * padding + 1)
-        zero = jnp.zeros((1, ), dtype=jnp.int32)
-        crop_from = jnp.concatenate([zero, crop_from, zero])
-        padded_img = jnp.pad(img, ((0, 0), (padding, padding), (padding, padding), (0, 0)),
-                             mode='edge')
+        crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
+        zero = jnp.zeros((1,), dtype=jnp.int32)
+        crop_from = jnp.concatenate([zero, crop_from, zero])
+        padded_img = jnp.pad(
+            img, ((0, 0), (padding, padding), (padding, padding), (0, 0)), mode="edge"
+        )
         return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)
 
       out = {}
       for k_view, v_view in ub_obs.items():
-        if k_view.startswith('pixels/'):
+        if k_view.startswith("pixels/"):
           key, key_shift = jax.random.split(key)
           out[k_view] = rt_view(key_shift, v_view, 4)
-      ub_obs = ub_obs.copy(out) # Update the shifted fields
+      ub_obs = ub_obs.copy(out)  # Update the shifted fields
       return ub_obs
 
     bdim = next(iter(obs.items()), None)[1].shape[0]
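Illustration (not part of this commit): a standalone version of the pad-and-crop translation used by rt_view above, with illustrative shapes. Height and width are padded by `padding` pixels with edge values, then a window of the original size is cropped at a random offset in [0, 2 * padding]; the time and channel axes are never shifted.

import jax
import jax.numpy as jnp

def random_translate(key, img, padding=4):  # img: (T, H, W, C)
  crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
  zero = jnp.zeros((1,), dtype=jnp.int32)
  crop_from = jnp.concatenate([zero, crop_from, zero])  # no shift on T or C
  padded = jnp.pad(
      img, ((0, 0), (padding, padding), (padding, padding), (0, 0)), mode="edge"
  )
  return jax.lax.dynamic_slice(padded, crop_from, img.shape)

img = jnp.zeros((2, 64, 64, 3))
shifted = random_translate(jax.random.PRNGKey(0), img)  # same shape as img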
@@ -385,7 +387,8 @@ def f(carry, unused_t):
     normalizer_params = running_statistics.update(
         training_state.normalizer_params,
         remove_pixels(data.observation),
-        pmap_axis_name=_PMAP_AXIS_NAME)
+        pmap_axis_name=_PMAP_AXIS_NAME
+    )
 
     (optimizer_state, params, _), metrics = jax.lax.scan(
         functools.partial(