diff --git a/.gitignore b/.gitignore
index 57c1273..8b13789 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1 @@
-data/*
-images/*
-configs/*
 
diff --git a/README.md b/README.md
index d4ff9b6..a7acfcb 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ Install the requirements via `pip install -r requirements.txt`.
 
 ### Run Experiments
 Run the experiments via `python -m run_experiment -c config`, where `config` is a configuration file in the `./configs/` directory. The available values are `{config_mnist, config_cifar10, config_spotify, config_recogym, config_personalization}`, which can be specified to recreate each of the presented numerical experiments.
-Optionally, a custom experiment can be set up by changing the `./configs/config.yml` file.
+Optionally, a custom experiment can be set up by changing or adding a new configuration file.
 
 ### Load Experiments
 All previously performed experiments are stored in the `./data/` directory and can be recreated by loading via `python -m run_experiment -l exp_name`, where `exp_name` is the name of the experiment as it is saved in `./data/`.
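For orientation, the two flags referenced above are handled in `run_experiment.py` (renamed from `main.py` later in this diff). The following is a minimal sketch of how the `-c`/`-l` arguments are likely resolved; the actual loading code is not part of this diff, so the `yaml`/`pickle` calls are assumptions based on the files added under `./configs/` and `./data/`:

```python
import argparse
import pickle

import yaml

# argparse lines copied verbatim from the run_experiment.py hunk below
parser = argparse.ArgumentParser()
parser.add_argument('--load', '-l', default=False)
parser.add_argument('--config', '-c', default='config_mnist')
args = parser.parse_args()

if args.load:
    # replay a stored experiment from its ./data/ snapshot
    with open(f'./data/{args.load}.pkl', 'rb') as f:
        results = pickle.load(f)
    with open(f'./data/{args.load}.yml') as f:
        params = yaml.safe_load(f)
else:
    # start a fresh experiment from a config in ./configs/
    with open(f'./configs/{args.config}.yml') as f:
        params = yaml.safe_load(f)
```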
diff --git a/configs/config_cifar10.yml b/configs/config_cifar10.yml
new file mode 100644
index 0000000..48e4680
--- /dev/null
+++ b/configs/config_cifar10.yml
@@ -0,0 +1,26 @@
+
+exp_name: cifar10
+seed: 2022
+
+env_name: cifar10
+
+params_exp:
+  arch: [128,128]
+  num_timesteps: 0.5e+6
+  eval_interval: 2500
+
+params_agents:
+
+  'A2C':
+    alg_type: A2C
+    hyperparams: {learning_rate: 1.0e-4}
+    batch_size: 32
+  'DQN':
+    alg_type: DQN
+    hyperparams: {learning_rate: 5.0e-6}
+    batch_size: 32
+  'PPO':
+    alg_type: PPO
+    hyperparams: {learning_rate: 1.0e-4}
+    batch_size: 32
+
diff --git a/configs/config_mnist.yml b/configs/config_mnist.yml
new file mode 100644
index 0000000..ea9ec7a
--- /dev/null
+++ b/configs/config_mnist.yml
@@ -0,0 +1,26 @@
+
+exp_name: mnist
+seed: 2022
+
+env_name: mnist
+
+params_exp:
+  arch: [64,64]
+  num_timesteps: 0.5e+5
+  eval_interval: 250
+
+params_agents:
+
+  'A2C':
+    alg_type: A2C
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+  'DQN':
+    alg_type: DQN
+    hyperparams: {learning_rate: 3.0e-4}
+    batch_size: 32
+  'PPO':
+    alg_type: PPO
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+
diff --git a/configs/config_personalization.yml b/configs/config_personalization.yml
new file mode 100644
index 0000000..8b18542
--- /dev/null
+++ b/configs/config_personalization.yml
@@ -0,0 +1,26 @@
+
+exp_name: personalization
+seed: 2022
+
+env_name: personalization
+
+params_exp:
+  arch: [512,512]
+  num_timesteps: 1.0e+5
+  eval_interval: 500
+
+params_agents:
+
+  'A2C':
+    alg_type: A2C
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+  'DQN':
+    alg_type: DQN
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+  'PPO':
+    alg_type: PPO
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+
diff --git a/configs/config_recogym.yml b/configs/config_recogym.yml
new file mode 100644
index 0000000..6b39ece
--- /dev/null
+++ b/configs/config_recogym.yml
@@ -0,0 +1,26 @@
+
+exp_name: recogym
+seed: 2022
+
+env_name: recogym
+
+params_exp:
+  arch: [512,512]
+  num_timesteps: 5.0e+5
+  eval_interval: 2500
+
+params_agents:
+
+  'A2C':
+    alg_type: A2C
+    hyperparams: {learning_rate: 1.0e-4}
+    batch_size: 32
+  'DQN':
+    alg_type: DQN
+    hyperparams: {learning_rate: 1.0e-4}
+    batch_size: 32
+  'PPO':
+    alg_type: PPO
+    hyperparams: {learning_rate: 1.0e-4}
+    batch_size: 32
+
diff --git a/configs/config_spotify.yml b/configs/config_spotify.yml
new file mode 100644
index 0000000..c34c31b
--- /dev/null
+++ b/configs/config_spotify.yml
@@ -0,0 +1,26 @@
+
+exp_name: spotify
+seed: 2022
+
+env_name: spotify
+
+params_exp:
+  arch: [256,256]
+  num_timesteps: 0.2e+5
+  eval_interval: 100
+
+params_agents:
+
+  'A2C':
+    alg_type: A2C
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+  'DQN':
+    alg_type: DQN
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+  'PPO':
+    alg_type: PPO
+    hyperparams: {learning_rate: 1.0e-3}
+    batch_size: 32
+
diff --git a/data/cifar10.pkl b/data/cifar10.pkl
new file mode 100644
index 0000000..60d5335
Binary files /dev/null and b/data/cifar10.pkl differ
diff --git a/data/cifar10.yml b/data/cifar10.yml
new file mode 100644
index 0000000..60946e7
--- /dev/null
+++ b/data/cifar10.yml
@@ -0,0 +1,25 @@
+env_name: cifar10
+exp_name: cifar10
+params_agents:
+  A2C:
+    alg_type: A2C
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.0001
+  DQN:
+    alg_type: DQN
+    batch_size: 32
+    hyperparams:
+      learning_rate: 5.0e-06
+  PPO:
+    alg_type: PPO
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.0001
+params_exp:
+  arch:
+  - 128
+  - 128
+  eval_interval: 2500
+  num_timesteps: 500000.0
+seed: 2022
diff --git a/data/mnist.pkl b/data/mnist.pkl
new file mode 100644
index 0000000..135ab50
Binary files /dev/null and b/data/mnist.pkl differ
diff --git a/data/mnist.yml b/data/mnist.yml
new file mode 100644
index 0000000..21ea6b6
--- /dev/null
+++ b/data/mnist.yml
@@ -0,0 +1,25 @@
+env_name: mnist
+exp_name: mnist
+params_agents:
+  A2C:
+    alg_type: A2C
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+  DQN:
+    alg_type: DQN
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.0003
+  PPO:
+    alg_type: PPO
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+params_exp:
+  arch:
+  - 64
+  - 64
+  eval_interval: 250
+  num_timesteps: 50000.0
+seed: 2022
diff --git a/data/personalization.pkl b/data/personalization.pkl
new file mode 100644
index 0000000..8dbe82e
Binary files /dev/null and b/data/personalization.pkl differ
diff --git a/data/personalization.yml b/data/personalization.yml
new file mode 100644
index 0000000..c32b888
--- /dev/null
+++ b/data/personalization.yml
@@ -0,0 +1,25 @@
+env_name: personalization
+exp_name: personalization
+params_agents:
+  A2C:
+    alg_type: A2C
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+  DQN:
+    alg_type: DQN
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+  PPO:
+    alg_type: PPO
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+params_exp:
+  arch:
+  - 512
+  - 512
+  eval_interval: 500
+  num_timesteps: 100000.0
+seed: 2022
diff --git a/data/recogym.pkl b/data/recogym.pkl
new file mode 100644
index 0000000..c1cc732
Binary files /dev/null and b/data/recogym.pkl differ
diff --git a/data/recogym.yml b/data/recogym.yml
new file mode 100644
index 0000000..6884480
--- /dev/null
+++ b/data/recogym.yml
@@ -0,0 +1,25 @@
+env_name: recogym
+exp_name: recogym
+params_agents:
+  A2C:
+    alg_type: A2C
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.0001
+  DQN:
+    alg_type: DQN
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.0001
+  PPO:
+    alg_type: PPO
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.0001
+params_exp:
+  arch:
+  - 512
+  - 512
+  eval_interval: 2500
+  num_timesteps: 500000.0
+seed: 2022
diff --git a/data/spotify.pkl b/data/spotify.pkl
new file mode 100644
index 0000000..9fa8730
Binary files /dev/null and b/data/spotify.pkl differ
diff --git a/data/spotify.yml b/data/spotify.yml
new file mode 100644
index 0000000..24cf729
--- /dev/null
+++ b/data/spotify.yml
@@ -0,0 +1,25 @@
+env_name: spotify
+exp_name: spotify
+params_agents:
+  A2C:
+    alg_type: A2C
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+  DQN:
+    alg_type: DQN
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+  PPO:
+    alg_type: PPO
+    batch_size: 32
+    hyperparams:
+      learning_rate: 0.001
+params_exp:
+  arch:
+  - 256
+  - 256
+  eval_interval: 100
+  num_timesteps: 20000.0
+seed: 2022
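Note that each `./data/<exp_name>.yml` above is a normalized snapshot of the corresponding `./configs/` file: keys are sorted alphabetically, quoted agent names lose their quotes, and scientific-notation numbers are expanded (e.g. `0.5e+6` becomes `500000.0`). That formatting is consistent with a plain PyYAML round trip, sketched here as a guess rather than the repository's actual code:

```python
import yaml

# load a hand-written config; YAML parses '0.5e+5' to the float 50000.0
with open('./configs/config_mnist.yml') as f:
    params = yaml.safe_load(f)

# dump it back out; safe_dump sorts keys by default, matching ./data/mnist.yml's layout
with open('./data/mnist.yml', 'w') as f:
    yaml.safe_dump(params, f, default_flow_style=False)
```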
diff --git a/main.py b/run_experiment.py
similarity index 95%
rename from main.py
rename to run_experiment.py
index 5e86f41..072b6d1 100644
--- a/main.py
+++ b/run_experiment.py
@@ -12,7 +12,7 @@
     # parse arguments
     parser = argparse.ArgumentParser()
     parser.add_argument('--load', '-l', default=False)
-    parser.add_argument('--config', '-c', default='config')
+    parser.add_argument('--config', '-c', default='config_mnist')
     args = parser.parse_args()
 
     # obtain experiment data
diff --git a/visualization.py b/visualization.py
index 7684ac6..3836c5c 100644
--- a/visualization.py
+++ b/visualization.py
@@ -11,7 +11,7 @@
 sbn_mute = ['#66aadd', '#ddaa66', '#66ddaa', '#aa66dd', '#dd66aa', '#aadd66']
 sbn_base = np.array([sbn_bold, sbn_mute]).flatten(order='C')
 sbn_pair = np.array([sbn_bold, sbn_mute]).flatten(order='F')
-sns.set_theme(style='darkgrid', palette=sbn_base, font='monospace')
+sns.set_theme(style='darkgrid', palette=sbn_base, font='monospace', font_scale=1.5)
 
 
 class DataVisualization:
@@ -49,6 +49,7 @@ def load_data(self, exp_name):
     def visualize_agents(self):
         '''plot various metrics'''
         self.plot_distributions(show=False)
+        self.plot_distributions(sort=False, show=False)
         self.plot_entropy(show=False)
         self.plot_rewards(show=True)
 
@@ -56,32 +57,35 @@ def plot_distributions(self, step=1, sort=True, show=True):
         '''plot agents' test histograms throughout the training process'''
         sns.set_palette('Paired')
         for name, agent in self.agents.items():
-            fig, ax = plt.subplots(figsize=(6,4))
+            fig, ax = plt.subplots(figsize=(6,3))
             df = pd.DataFrame(agent['hist'], index=agent['eval_steps'])[::step]
             if sort:
                 df.values[:,::-1].sort(axis=1)
             df.plot.bar(stacked=True, width=1, ax=ax, linewidth=.1, legend=None)
             plt.xticks(np.linspace(0, len(df) - 1, 7), rotation=0)
             ax.set_ylim(0,1)
-            ax.set_xlabel('number of agent-environment interactions')
-            ax.set_ylabel(f'action distribution')
+            ##ax.set_xlabel('number of agent-environment interactions')
+            ##ax.set_ylabel(f'action distribution')
             plt.tight_layout()
-            plt.savefig(f'./images/{self.exp_name}_dist_{name}.png', dpi=300, format='png')
+            if sort:
+                plt.savefig(f'./images/{self.exp_name}_dist_{name}.png', dpi=300, format='png')
+            else:
+                plt.savefig(f'./images/{self.exp_name}_dist_{name}_raw.png', dpi=300, format='png')
             if show:
                 plt.show()
             else:
                 plt.close()
 
-    def plot_entropy(self, show=True):
+    def plot_entropy(self, step=1, show=True):
         '''plot agents' entropy throughout the training process'''
         sns.set_palette(sbn_base)
-        fig, ax = plt.subplots(figsize=(9,4))
+        fig, ax = plt.subplots(figsize=(9,3))
         for name, agent in self.agents.items():
             agent_ent = [np.sum(-np.array(h) * np.log(h)) for h in agent['hist']]
-            plt.plot(agent['eval_steps'], agent_ent, linewidth=4, label=name)
-            ax.set_xlabel('number of agent-environment interactions')
-            ax.set_ylabel('entropy')
-            ax.legend(loc='lower left')
+            plt.plot(agent['eval_steps'][::step], agent_ent[::step], linewidth=4, label=name)
+            ##ax.set_xlabel('number of agent-environment interactions')
+            ##ax.set_ylabel('entropy')
+            ax.legend(loc='lower right')
         plt.tight_layout()
         plt.savefig(f'./images/{self.exp_name}_entropy.pdf', format='pdf')
         if show:
@@ -89,14 +93,14 @@
             plt.show()
         else:
             plt.close()
-    def plot_rewards(self, show=True):
+    def plot_rewards(self, step=1, show=True):
         '''plot evaluation rewards throughout the training process'''
         sns.set_palette(sbn_base)
-        fig, ax = plt.subplots(figsize=(9,4))
+        fig, ax = plt.subplots(figsize=(9,3))
         for name, agent in self.agents.items():
-            plt.plot(agent['eval_steps'], agent['eval'], linewidth=4, label=name)
-            ax.set_xlabel('number of agent-environment interactions')
-            ax.set_ylabel('stochastic evaluation rewards')
+            plt.plot(agent['eval_steps'][::step], agent['eval'][::step], linewidth=4, label=name)
+            ##ax.set_xlabel('number of agent-environment interactions')
+            ##ax.set_ylabel('stochastic evaluation rewards')
             ax.legend(loc='lower right')
         plt.tight_layout()
         plt.savefig(f'./images/{self.exp_name}_reward.pdf', format='pdf')
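As a usage note, the entropy curve produced by `plot_entropy` can be recomputed offline from a stored run. Below is a minimal sketch, assuming the pickle holds a `{name: {'hist': ..., 'eval_steps': ...}}` mapping as the `agent['hist']` accesses in `visualization.py` suggest; the `eps` guard against `log(0)` on empty action bins is an addition, since the list comprehension above would return `nan` for such histograms:

```python
import pickle

import numpy as np

with open('./data/mnist.pkl', 'rb') as f:
    agents = pickle.load(f)  # assumed layout: {agent_name: {'hist': [...], 'eval_steps': [...]}}

eps = 1e-12  # guard against log(0); not present in the plotted version
for name, agent in agents.items():
    hist = np.asarray(agent['hist'])                      # one action distribution per eval step
    entropy = -np.sum(hist * np.log(hist + eps), axis=1)  # Shannon entropy at each eval step
    print(f'{name}: final entropy {entropy[-1]:.3f}')
```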