add configs, data, images
sukiboo committed Aug 14, 2022
1 parent b180276 commit 5dce687
Showing 19 changed files with 277 additions and 21 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -1,4 +1 @@
-data/*
-images/*
-configs/*

2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ Install the requirements via `pip install -r requirements.txt`.
### Run Experiments
Run the experiments via `python -m run_experiment -c config`, where `config` is a configuration file in `./configs/` directory.
The available values are `{config_mnist, config_cifar10, config_spotify, config_recogym, config_personalization}`, which could be specified to recreate each of the presented numerical experiments.
-Optionally, a custom experiment can be set up by changing the `./configs/config.yml` file.
+Optionally, a custom experiment can be set up by changing an existing configuration file or adding a new one.

### Load Experiments
All previously performed experiments are stored in `./data/` directory and can be recreated by loading via `python -m run_experiment -l exp_name`, where `exp_name` is the name of the experiment as it is saved in `./data/`.
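For example, `python -m run_experiment -l mnist` reloads the `mnist` experiment added in this commit, whose results are stored in `data/mnist.pkl` alongside its resolved configuration in `data/mnist.yml`.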
26 changes: 26 additions & 0 deletions configs/config_cifar10.yml
@@ -0,0 +1,26 @@

exp_name: cifar10
seed: 2022

env_name: cifar10

params_exp:
  arch: [128,128]
  num_timesteps: 0.5e+6
  eval_interval: 2500

params_agents:

  'A2C':
    alg_type: A2C
    hyperparams: {learning_rate: 1.0e-4}
    batch_size: 32
  'DQN':
    alg_type: DQN
    hyperparams: {learning_rate: 5.0e-6}
    batch_size: 32
  'PPO':
    alg_type: PPO
    hyperparams: {learning_rate: 1.0e-4}
    batch_size: 32
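Editor's note: the configuration files in this commit are plain YAML, and a minimal sketch of reading one follows, assuming the project loads them with PyYAML (the normalized `data/*.yml` dumps later in this diff are consistent with that). One detail worth flagging is the float formatting: a YAML 1.1 resolver such as PyYAML's only recognizes scientific notation that has both a decimal point and a signed exponent, so `0.5e+6` and `1.0e-4` load as floats, whereas `5e5` or `1e-4` would load as strings.

```python
import yaml  # PyYAML -- an assumption about the loader used by this project

# Sketch: read an experiment config such as ./configs/config_cifar10.yml.
with open('./configs/config_cifar10.yml') as f:
    config = yaml.safe_load(f)

# PyYAML (YAML 1.1) parses '0.5e+6' as a float because it has a decimal
# point and a signed exponent; '5e5' would come back as a string instead.
print(config['params_exp']['num_timesteps'])          # 500000.0
print(config['params_agents']['DQN']['hyperparams'])  # {'learning_rate': 5e-06}
```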

26 changes: 26 additions & 0 deletions configs/config_mnist.yml
@@ -0,0 +1,26 @@

exp_name: mnist
seed: 2022

env_name: mnist

params_exp:
  arch: [64,64]
  num_timesteps: 0.5e+5
  eval_interval: 250

params_agents:

  'A2C':
    alg_type: A2C
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32
  'DQN':
    alg_type: DQN
    hyperparams: {learning_rate: 3.0e-4}
    batch_size: 32
  'PPO':
    alg_type: PPO
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32

26 changes: 26 additions & 0 deletions configs/config_personalization.yml
@@ -0,0 +1,26 @@

exp_name: personalization
seed: 2022

env_name: personalization

params_exp:
  arch: [512,512]
  num_timesteps: 1.0e+5
  eval_interval: 500

params_agents:

  'A2C':
    alg_type: A2C
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32
  'DQN':
    alg_type: DQN
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32
  'PPO':
    alg_type: PPO
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32

26 changes: 26 additions & 0 deletions configs/config_recogym.yml
@@ -0,0 +1,26 @@

exp_name: recogym
seed: 2022

env_name: recogym

params_exp:
  arch: [512,512]
  num_timesteps: 5.0e+5
  eval_interval: 2500

params_agents:

  'A2C':
    alg_type: A2C
    hyperparams: {learning_rate: 1.0e-4}
    batch_size: 32
  'DQN':
    alg_type: DQN
    hyperparams: {learning_rate: 1.0e-4}
    batch_size: 32
  'PPO':
    alg_type: PPO
    hyperparams: {learning_rate: 1.0e-4}
    batch_size: 32

26 changes: 26 additions & 0 deletions configs/config_spotify.yml
@@ -0,0 +1,26 @@

exp_name: spotify
seed: 2022

env_name: spotify

params_exp:
  arch: [256,256]
  num_timesteps: 0.2e+5
  eval_interval: 100

params_agents:

  'A2C':
    alg_type: A2C
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32
  'DQN':
    alg_type: DQN
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32
  'PPO':
    alg_type: PPO
    hyperparams: {learning_rate: 1.0e-3}
    batch_size: 32

Binary file added data/cifar10.pkl
Binary file not shown.
25 changes: 25 additions & 0 deletions data/cifar10.yml
@@ -0,0 +1,25 @@
env_name: cifar10
exp_name: cifar10
params_agents:
  A2C:
    alg_type: A2C
    batch_size: 32
    hyperparams:
      learning_rate: 0.0001
  DQN:
    alg_type: DQN
    batch_size: 32
    hyperparams:
      learning_rate: 5.0e-06
  PPO:
    alg_type: PPO
    batch_size: 32
    hyperparams:
      learning_rate: 0.0001
params_exp:
  arch:
  - 128
  - 128
  eval_interval: 2500
  num_timesteps: 500000.0
seed: 2022
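Editor's note: the paired `.pkl` files are binary and not shown in this diff. Judging by the keys `visualization.py` reads further down (`agent['hist']`, `agent['eval_steps']`, `agent['eval']`), each one plausibly stores per-agent evaluation records; a hedged sketch of inspecting one, assuming a pickled dict keyed by agent name:

```python
import pickle

# Assumption: data/<exp>.pkl holds a dict of per-agent records with the
# keys that visualization.py accesses ('hist', 'eval_steps', 'eval').
with open('./data/cifar10.pkl', 'rb') as f:
    agents = pickle.load(f)

for name, agent in agents.items():  # e.g. 'A2C', 'DQN', 'PPO'
    print(name, len(agent['eval_steps']), agent['eval'][-1])
```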
Binary file added data/mnist.pkl
Binary file not shown.
25 changes: 25 additions & 0 deletions data/mnist.yml
@@ -0,0 +1,25 @@
env_name: mnist
exp_name: mnist
params_agents:
  A2C:
    alg_type: A2C
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
  DQN:
    alg_type: DQN
    batch_size: 32
    hyperparams:
      learning_rate: 0.0003
  PPO:
    alg_type: PPO
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
params_exp:
  arch:
  - 64
  - 64
  eval_interval: 250
  num_timesteps: 50000.0
seed: 2022
Binary file added data/personalization.pkl
Binary file not shown.
25 changes: 25 additions & 0 deletions data/personalization.yml
@@ -0,0 +1,25 @@
env_name: personalization
exp_name: personalization
params_agents:
  A2C:
    alg_type: A2C
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
  DQN:
    alg_type: DQN
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
  PPO:
    alg_type: PPO
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
params_exp:
  arch:
  - 512
  - 512
  eval_interval: 500
  num_timesteps: 100000.0
seed: 2022
Binary file added data/recogym.pkl
Binary file not shown.
25 changes: 25 additions & 0 deletions data/recogym.yml
@@ -0,0 +1,25 @@
env_name: recogym
exp_name: recogym
params_agents:
  A2C:
    alg_type: A2C
    batch_size: 32
    hyperparams:
      learning_rate: 0.0001
  DQN:
    alg_type: DQN
    batch_size: 32
    hyperparams:
      learning_rate: 0.0001
  PPO:
    alg_type: PPO
    batch_size: 32
    hyperparams:
      learning_rate: 0.0001
params_exp:
  arch:
  - 512
  - 512
  eval_interval: 2500
  num_timesteps: 500000.0
seed: 2022
Binary file added data/spotify.pkl
Binary file not shown.
25 changes: 25 additions & 0 deletions data/spotify.yml
@@ -0,0 +1,25 @@
env_name: spotify
exp_name: spotify
params_agents:
  A2C:
    alg_type: A2C
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
  DQN:
    alg_type: DQN
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
  PPO:
    alg_type: PPO
    batch_size: 32
    hyperparams:
      learning_rate: 0.001
params_exp:
  arch:
  - 256
  - 256
  eval_interval: 100
  num_timesteps: 20000.0
seed: 2022
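Editor's note: these `data/*.yml` files read like normalized dumps of the corresponding `configs/*.yml`: keys are alphabetized, lists are block-style, and the scientific-notation floats appear in plain form (`0.2e+5` becomes `20000.0`). That is exactly what a PyYAML round trip produces with default settings, so a plausible sketch of how they are written:

```python
import yaml

# Sketch (an assumption about how data/spotify.yml is produced): load the
# config and dump it back out. yaml.dump sorts keys alphabetically by
# default and renders block-style lists, matching the files shown above.
with open('./configs/config_spotify.yml') as f:
    config = yaml.safe_load(f)

with open('./data/spotify.yml', 'w') as f:
    yaml.dump(config, f, default_flow_style=False)
```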
2 changes: 1 addition & 1 deletion main.py → run_experiment.py
@@ -12,7 +12,7 @@
# parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--load', '-l', default=False)
-parser.add_argument('--config', '-c', default='config')
+parser.add_argument('--config', '-c', default='config_mnist')
args = parser.parse_args()

# obtain experiment data
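Editor's note: the visible hunk stops at the `# obtain experiment data` comment, so the dispatch itself is not shown in this diff. Based on the two flags, a hypothetical sketch of the pattern (the file-handling below is illustrative, not the repository's actual code):

```python
import argparse
import pickle

import yaml

# Hypothetical reconstruction of the dispatch implied by -l/--load and
# -c/--config; the real logic after '# obtain experiment data' is hidden.
parser = argparse.ArgumentParser()
parser.add_argument('--load', '-l', default=False)
parser.add_argument('--config', '-c', default='config_mnist')
args = parser.parse_args()

if args.load:
    # -l exp_name: reload a stored experiment from ./data/<exp_name>.pkl
    with open(f'./data/{args.load}.pkl', 'rb') as f:
        data = pickle.load(f)
else:
    # -c config: set up a new experiment from ./configs/<config>.yml
    with open(f'./configs/{args.config}.yml') as f:
        params = yaml.safe_load(f)
```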
36 changes: 20 additions & 16 deletions visualization.py
@@ -11,7 +11,7 @@
sbn_mute = ['#66aadd', '#ddaa66', '#66ddaa', '#aa66dd', '#dd66aa', '#aadd66']
sbn_base = np.array([sbn_bold, sbn_mute]).flatten(order='C')
sbn_pair = np.array([sbn_bold, sbn_mute]).flatten(order='F')
-sns.set_theme(style='darkgrid', palette=sbn_base, font='monospace')
+sns.set_theme(style='darkgrid', palette=sbn_base, font='monospace', font_scale=1.5)


class DataVisualization:
@@ -49,54 +49,58 @@ def load_data(self, exp_name):
    def visualize_agents(self):
        '''plot various metrics'''
        self.plot_distributions(show=False)
+        self.plot_distributions(sort=False, show=False)
        self.plot_entropy(show=False)
        self.plot_rewards(show=True)

    def plot_distributions(self, step=1, sort=True, show=True):
        '''plot agents' test histograms throughout the training process'''
        sns.set_palette('Paired')
        for name, agent in self.agents.items():
-            fig, ax = plt.subplots(figsize=(6,4))
+            fig, ax = plt.subplots(figsize=(6,3))
            df = pd.DataFrame(agent['hist'], index=agent['eval_steps'])[::step]
            if sort:
                df.values[:,::-1].sort(axis=1)
            df.plot.bar(stacked=True, width=1, ax=ax, linewidth=.1, legend=None)
            plt.xticks(np.linspace(0, len(df) - 1, 7), rotation=0)
            ax.set_ylim(0,1)
-            ax.set_xlabel('number of agent-environment interactions')
-            ax.set_ylabel(f'action distribution')
+            ##ax.set_xlabel('number of agent-environment interactions')
+            ##ax.set_ylabel(f'action distribution')
            plt.tight_layout()
-            plt.savefig(f'./images/{self.exp_name}_dist_{name}.png', dpi=300, format='png')
+            if sort:
+                plt.savefig(f'./images/{self.exp_name}_dist_{name}.png', dpi=300, format='png')
+            else:
+                plt.savefig(f'./images/{self.exp_name}_dist_{name}_raw.png', dpi=300, format='png')
            if show:
                plt.show()
            else:
                plt.close()

-    def plot_entropy(self, show=True):
+    def plot_entropy(self, step=1, show=True):
        '''plot agents' entropy throughout the training process'''
        sns.set_palette(sbn_base)
-        fig, ax = plt.subplots(figsize=(9,4))
+        fig, ax = plt.subplots(figsize=(9,3))
        for name, agent in self.agents.items():
            agent_ent = [np.sum(-np.array(h) * np.log(h)) for h in agent['hist']]
-            plt.plot(agent['eval_steps'], agent_ent, linewidth=4, label=name)
-        ax.set_xlabel('number of agent-environment interactions')
-        ax.set_ylabel('entropy')
-        ax.legend(loc='lower left')
+            plt.plot(agent['eval_steps'][::step], agent_ent[::step], linewidth=4, label=name)
+        ##ax.set_xlabel('number of agent-environment interactions')
+        ##ax.set_ylabel('entropy')
+        ax.legend(loc='lower right')
        plt.tight_layout()
        plt.savefig(f'./images/{self.exp_name}_entropy.pdf', format='pdf')
        if show:
            plt.show()
        else:
            plt.close()

-    def plot_rewards(self, show=True):
+    def plot_rewards(self, step=1, show=True):
        '''plot evaluation rewards throughout the training process'''
        sns.set_palette(sbn_base)
-        fig, ax = plt.subplots(figsize=(9,4))
+        fig, ax = plt.subplots(figsize=(9,3))
        for name, agent in self.agents.items():
-            plt.plot(agent['eval_steps'], agent['eval'], linewidth=4, label=name)
-        ax.set_xlabel('number of agent-environment interactions')
-        ax.set_ylabel('stochastic evaluation rewards')
+            plt.plot(agent['eval_steps'][::step], agent['eval'][::step], linewidth=4, label=name)
+        ##ax.set_xlabel('number of agent-environment interactions')
+        ##ax.set_ylabel('stochastic evaluation rewards')
        ax.legend(loc='lower right')
        plt.tight_layout()
        plt.savefig(f'./images/{self.exp_name}_reward.pdf', format='pdf')
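Editor's note on the entropy computation above: `agent_ent` is the Shannon entropy H(h) = -sum_i h_i * log(h_i) of each evaluated action distribution, so the curves measure how far each policy is from collapsing onto a single action. A standalone sketch, with an epsilon guard added here as a precaution (the line in the diff calls `np.log(h)` directly, which yields `nan` terms when some bin has zero probability):

```python
import numpy as np

# Shannon entropy of an action histogram h: H(h) = -sum_i h_i * log(h_i).
# The eps guard is an addition in this sketch; the repository code uses
# np.log(h) directly, which produces nan for zero-probability bins (0 * -inf).
def action_entropy(hist, eps=1e-12):
    h = np.asarray(hist, dtype=float)
    return float(np.sum(-h * np.log(h + eps)))

print(action_entropy([0.25, 0.25, 0.25, 0.25]))  # log(4) ≈ 1.386 (uniform policy)
print(action_entropy([1.0, 0.0, 0.0, 0.0]))      # ≈ 0 (fully collapsed policy)
```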
