From 146af049380e71d107d3f0b4e8b53fa04c992beb Mon Sep 17 00:00:00 2001 From: Sebastian Dittert Date: Tue, 3 Oct 2023 19:05:17 +0200 Subject: [PATCH] [Algorithm] Update SAC Example (#1524) Co-authored-by: vmoens --- .../linux_examples/scripts/run_test.sh | 14 +- examples/sac/config.yaml | 40 ++--- examples/sac/sac.py | 158 +++++++++++------- examples/sac/utils.py | 104 ++++++++---- torchrl/objectives/sac.py | 5 +- 5 files changed, 197 insertions(+), 124 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 2a9e258c35a..a6e09a51a43 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -118,11 +118,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ - collector.num_workers=4 \ collector.env_per_collector=2 \ collector.collector_device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ network.device=cuda:0 \ @@ -221,17 +220,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ - collector.num_workers=2 \ collector.env_per_collector=1 \ collector.collector_device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ network.device=cuda:0 \ - optimization.batch_size=10 \ - optimization.utd_ratio=1 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ logger.backend= -# record_video=True \ -# record_frames=4 \ python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \ total_frames=48 \ batch_size=10 \ diff --git a/examples/sac/config.yaml b/examples/sac/config.yaml index 22cba615d30..2d3425a2151 100644 --- a/examples/sac/config.yaml +++ b/examples/sac/config.yaml @@ -1,41 +1,41 @@ -# Environment +# environment and task env: name: HalfCheetah-v3 task: "" - exp_name: "HalfCheetah-SAC" - library: gym - frame_skip: 1 - seed: 1 + exp_name: ${env.name}_SAC + library: gymnasium + max_episode_steps: 1000 + seed: 42 -# Collection +# collector collector: - total_frames: 1000000 - init_random_frames: 10000 + total_frames: 1_000_000 + init_random_frames: 25000 frames_per_batch: 1000 - max_frames_per_traj: 1000 init_env_steps: 1000 - async_collection: 1 collector_device: cpu env_per_collector: 1 - num_workers: 1 + reset_at_each_iter: False -# Replay Buffer +# replay buffer replay_buffer: size: 1000000 prb: 0 # use prioritized experience replay + scratch_dir: ${env.exp_name}_${env.seed} -# Optimization -optimization: +# optim +optim: utd_ratio: 1.0 gamma: 0.99 - loss_function: smooth_l1 - lr: 3e-4 - weight_decay: 2e-4 - lr_scheduler: "" + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 batch_size: 256 target_update_polyak: 0.995 + alpha_init: 1.0 + adam_eps: 1.0e-8 -# Algorithm +# network network: hidden_sizes: [256, 256] activation: relu @@ -43,7 +43,7 @@ network: scale_lb: 0.1 device: "cuda:0" -# Logging +# logging logger: backend: wandb mode: online diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 17b997cfda6..33b932ec42c 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -11,17 +11,20 @@ The helper functions are coded in the utils.py associated with this script. 
""" +import time + import hydra import numpy as np import torch import torch.cuda import tqdm - +from tensordict import TensorDict from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + log_metrics, make_collector, make_environment, make_loss_module, @@ -35,6 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + # Create logger exp_name = generate_exp_name("SAC", cfg.env.exp_name) logger = None if cfg.logger.backend: @@ -48,132 +52,158 @@ def main(cfg: "DictConfig"): # noqa: F821 torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - # Create Environments + # Create environments train_env, eval_env = make_environment(cfg) - # Create Agent + + # Create agent model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device) - # Create TD3 loss + # Create SAC loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Make Off-Policy Collector + # Create off-policy collector collector = make_collector(cfg, train_env, exploration_policy) - # Make Replay Buffer + # Create replay buffer replay_buffer = make_replay_buffer( - batch_size=cfg.optimization.batch_size, + batch_size=cfg.optim.batch_size, prb=cfg.replay_buffer.prb, buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir, device=device, ) - # Make Optimizers - optimizer = make_sac_optimizer(cfg, loss_module) - - rewards = [] - rewards_eval = [] + # Create optimizers + ( + optimizer_actor, + optimizer_critic, + optimizer_alpha, + ) = make_sac_optimizer(cfg, loss_module) # Main loop + start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - r0 = None - q_loss = None init_random_frames = cfg.collector.init_random_frames num_updates = int( cfg.collector.env_per_collector * cfg.collector.frames_per_batch - * cfg.optimization.utd_ratio + * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - env_per_collector = cfg.collector.env_per_collector eval_iter = cfg.logger.eval_iter - frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip - eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.env.max_episode_steps + sampling_start = time.time() for i, tensordict in enumerate(collector): - # update weights of the inference policy + sampling_time = time.time() - sampling_start + + # Update weights of the inference policy collector.update_policy_weights_() - if r0 is None: - r0 = tensordict["next", "reward"].sum(-1).mean().item() pbar.update(tensordict.numel()) - tensordict = tensordict.view(-1) + tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() + # Add to replay buffer replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames - # optimization steps + # Optimization steps + training_start = time.time() if collected_frames >= init_random_frames: - (actor_losses, q_losses, alpha_losses) = ([], [], []) - for _ in range(num_updates): - # sample from replay buffer + losses = TensorDict( + {}, + batch_size=[ + num_updates, + ], + ) + for i in range(num_updates): + # Sample from replay buffer sampled_tensordict = replay_buffer.sample().clone() + # Compute loss loss_td = loss_module(sampled_tensordict) actor_loss = loss_td["loss_actor"] q_loss = loss_td["loss_qvalue"] alpha_loss = loss_td["loss_alpha"] - loss = actor_loss + q_loss + alpha_loss - 
optimizer.zero_grad() - loss.backward() - optimizer.step() + # Update actor + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() - q_losses.append(q_loss.item()) - actor_losses.append(actor_loss.item()) - alpha_losses.append(alpha_loss.item()) + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() - # update qnet_target params + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() + + losses[i] = loss_td.select( + "loss_actor", "loss_qvalue", "loss_alpha" + ).detach() + + # Update qnet_target params target_net_updater.step() - # update priority + # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - rewards.append( - (i, tensordict["next", "reward"].sum().item() / env_per_collector) + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] ) - train_log = { - "train_reward": rewards[-1][1], - "collected_frames": collected_frames, - } - if q_loss is not None: - train_log.update( - { - "actor_loss": np.mean(actor_losses), - "q_loss": np.mean(q_losses), - "alpha_loss": np.mean(alpha_losses), - "alpha": loss_td["alpha"], - "entropy": loss_td["entropy"], - } + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length ) - if logger is not None: - for key, value in train_log.items(): - logger.log_scalar(key, value, step=collected_frames) - if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip: + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item() + metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() + metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() + metrics_to_log["train/alpha"] = loss_td["alpha"].item() + metrics_to_log["train/entropy"] = loss_td["entropy"].item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], auto_cast_to_device=True, break_when_any_done=True, ) + eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - rewards_eval.append((i, eval_reward)) - eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})" - if logger is not None: - logger.log_scalar( - "evaluation_reward", rewards_eval[-1][1], step=collected_frames - ) - if len(rewards_eval): - pbar.set_description( - f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str - ) + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git 
a/examples/sac/utils.py b/examples/sac/utils.py index 9c6f71ffa6c..ebbee32057b 100644 --- a/examples/sac/utils.py +++ b/examples/sac/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import torch from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor @@ -6,8 +11,8 @@ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import RewardScaling +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal @@ -20,16 +25,22 @@ # ----------------- -def env_maker(task, frame_skip=1, device="cpu", from_pixels=False): - return GymEnv(task, device=device, frame_skip=frame_skip, from_pixels=from_pixels) +def env_maker(task, device="cpu"): + with set_gym_backend("gym"): + return GymEnv( + task, + device=device, + ) -def apply_env_transforms(env, reward_scaling=1.0): +def apply_env_transforms(env, max_episode_steps=1000): transformed_env = TransformedEnv( env, Compose( - RewardScaling(loc=0.0, scale=reward_scaling), + InitTracker(), + StepCounter(max_episode_steps), DoubleToFloat(), + RewardSum(), ), ) return transformed_env @@ -43,7 +54,7 @@ def make_environment(cfg): ) parallel_env.set_seed(cfg.env.seed) - train_env = apply_env_transforms(parallel_env) + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) eval_env = TransformedEnv( ParallelEnv( @@ -65,8 +76,8 @@ def make_collector(cfg, train_env, actor_model_explore): collector = SyncDataCollector( train_env, actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, - max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=cfg.collector.collector_device, ) @@ -114,17 +125,6 @@ def make_replay_buffer( # ----- -def get_activation(cfg): - if cfg.network.activation == "relu": - return nn.ReLU - elif cfg.network.activation == "tanh": - return nn.Tanh - elif cfg.network.activation == "leaky_relu": - return nn.LeakyReLU - else: - raise NotImplementedError - - def make_sac_agent(cfg, train_env, eval_env, device): """Make SAC agent.""" # Define Actor Network @@ -214,24 +214,68 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - loss_function=cfg.optimization.loss_function, + loss_function=cfg.optim.loss_function, delay_actor=False, delay_qvalue=True, + alpha_init=cfg.optim.alpha_init, ) - loss_module.make_value_estimator(gamma=cfg.optimization.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate( - loss_module, eps=cfg.optimization.target_update_polyak - ) + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) return loss_module, target_net_updater +def split_critic_params(critic_params): + critic1_params = [] + critic2_params = [] + + for 
param in critic_params: + data1, data2 = param.data.chunk(2, dim=0) + critic1_params.append(nn.Parameter(data1)) + critic2_params.append(nn.Parameter(data2)) + return critic1_params, critic2_params + + def make_sac_optimizer(cfg, loss_module): - """Make SAC optimizer.""" - optimizer = optim.Adam( - loss_module.parameters(), - lr=cfg.optimization.lr, - weight_decay=cfg.optimization.weight_decay, + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, ) - return optimizer + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=3.0e-4, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 09ca452fa19..4baf4a92d06 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -585,7 +585,8 @@ def _actor_loss( td_q = tensordict.select(*self.qvalue_network.in_keys) td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qnetworkN0( - td_q, self._cached_detached_qvalue_params # should we clone? + td_q, + self._cached_detached_qvalue_params, # should we clone? ) min_q_logprob = ( td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) @@ -711,7 +712,7 @@ def _qvalue_v2_loss( pred_val, target_value.expand_as(pred_val), loss_function=self.loss_function, - ).mean(0) + ).sum(0) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata
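
Note (not part of the patch): the central change to examples/sac/sac.py is that the single Adam optimizer over loss_module.parameters() is replaced by three optimizers, one each for the actor, the Q-networks, and the entropy coefficient, with each loss term backpropagated separately before a Polyak update of the target networks. The sketch below condenses that inner update step into one standalone function for readability. It is an illustration only: it assumes loss_module, replay_buffer, the three optimizers, and target_net_updater are constructed as in examples/sac/utils.py after this patch, and the helper name sac_update_step is hypothetical, not something the patch adds.

# Illustrative sketch of the decoupled SAC update introduced by this patch.
# Assumes the objects are built as in examples/sac/utils.py; the helper name
# `sac_update_step` is hypothetical and not part of the patch.
def sac_update_step(
    loss_module,          # torchrl.objectives.SACLoss
    replay_buffer,        # TensorDict(Prioritized)ReplayBuffer
    optimizer_actor,
    optimizer_critic,
    optimizer_alpha,
    target_net_updater,   # torchrl.objectives.SoftUpdate
    use_prb=False,
):
    # Sample a batch and compute the three SAC loss terms in one forward pass.
    sampled_tensordict = replay_buffer.sample().clone()
    loss_td = loss_module(sampled_tensordict)

    # Actor update.
    optimizer_actor.zero_grad()
    loss_td["loss_actor"].backward()
    optimizer_actor.step()

    # Critic (Q-value) update.
    optimizer_critic.zero_grad()
    loss_td["loss_qvalue"].backward()
    optimizer_critic.step()

    # Entropy-coefficient (alpha) update.
    optimizer_alpha.zero_grad()
    loss_td["loss_alpha"].backward()
    optimizer_alpha.step()

    # Polyak update of the target Q-networks.
    target_net_updater.step()

    # Refresh priorities when a prioritized buffer is used.
    if use_prb:
        replay_buffer.update_priority(sampled_tensordict)

    # Return detached losses for logging, as the patched training loop does.
    return loss_td.select("loss_actor", "loss_qvalue", "loss_alpha").detach()

Backpropagating the three terms separately works without retain_graph because SACLoss computes them on effectively disjoint graphs (the actor loss uses detached Q-value parameters, and the alpha loss uses a detached log-probability), which is the same pattern the patched training loop relies on.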