example_training_rllib.py
#!/usr/bin/env python3
# encoding: utf-8
"""
This file contains an example of training your agents using Ray RLlib.
Note that this code will not learn properly as-is: neither preprocessing of the data nor adaptation of the default PPO algorithm is done here.
"""
from pathlib import Path
from typing import Callable

import gymnasium
import numpy as np
import pettingzoo
import torch
from gymnasium import spaces
from pettingzoo.utils import BaseWrapper
from pettingzoo.utils.env import AgentID, ObsType
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module import MultiRLModule, MultiRLModuleSpec, RLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule
from ray.tune.registry import register_env

from utils import create_environment


class CustomWrapper(BaseWrapper):
    # This is an example of a custom wrapper that flattens the symbolic vector state of the environment.
    # Wrappers are useful to inject state pre-processing or features that do not need to be learned by the agent.

    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        return spaces.flatten_space(super().observation_space(agent))

    def observe(self, agent: AgentID) -> ObsType | None:
        obs = super().observe(agent)
        if obs is None:
            return None
        return obs.flatten()
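

# A minimal sketch (defined but never called) of what the flattening above does.
# The shapes here are illustrative assumptions, not the environment's real ones:
# spaces.flatten_space turns a Box of shape (N, 5) into a Box of shape (N * 5,),
# and obs.flatten() produces the matching 1-D array.
def _flattening_sketch() -> None:
    box = spaces.Box(low=-1.0, high=1.0, shape=(4, 5))
    flat_box = spaces.flatten_space(box)  # Box with shape (20,)
    obs = np.zeros((4, 5), dtype=np.float32)
    assert obs.flatten().shape == flat_box.shape  # both (20,)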


class CustomPredictFunction(Callable):
    """This is an example of an instantiation of the CustomPredictFunction that loads a trained RLlib algorithm
    from a checkpoint and extracts the policies from it."""

    def __init__(self, env):
        # Here you should load your trained model(s) from a checkpoint in your folder.
        best_checkpoint = (Path("results") / "learner_group" / "learner" / "rl_module").resolve()
        self.modules = MultiRLModule.from_checkpoint(best_checkpoint)

    def __call__(self, observation, agent, *args, **kwargs):
        rl_module = self.modules[agent]
        fwd_ins = {"obs": torch.Tensor(observation).unsqueeze(0)}
        fwd_outputs = rl_module.forward_inference(fwd_ins)
        action_dist_class = rl_module.get_inference_action_dist_cls()
        action_dist = action_dist_class.from_logits(fwd_outputs["action_dist_inputs"])
        action = action_dist.sample()[0].numpy()
        return action
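

# Hedged usage sketch for CustomPredictFunction: once training has written a
# checkpoint under "results", it can be used as a plain observation-to-action
# mapping. The agent name below ("archer_0") is illustrative.
#
#   predict_fn = CustomPredictFunction(env)
#   action = predict_fn(env.observe("archer_0"), "archer_0")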


def algo_config(id_env, policies, policies_to_train):
    config = (
        PPOConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .environment(env=id_env, disable_env_checking=True)
        .env_runners(num_env_runners=1)
        .multi_agent(
            policies=set(policies),
            policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
            policies_to_train=policies_to_train,
        )
        .rl_module(
            rl_module_spec=MultiRLModuleSpec(
                rl_module_specs={
                    x: RLModuleSpec(module_class=PPOTorchRLModule, model_config={"fcnet_hiddens": [64, 64]})
                    if x in policies_to_train
                    else RLModuleSpec(module_class=RandomRLModule)
                    for x in policies
                },
            )
        )
        .training(
            train_batch_size=512,
            lr=1e-4,
            gamma=0.99,
        )
        .debugging(log_level="ERROR")
    )
    return config
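

# Hedged note on the mapping above: policy_mapping_fn returns the agent id
# itself, so each PettingZoo agent (e.g. "archer_0") trains its own module.
# A single policy shared by all agents would instead look like (illustrative):
#   policy_mapping_fn=lambda agent_id, *args, **kwargs: "shared_policy"
# with one "shared_policy" entry in `policies` and in `rl_module_specs`.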


def training(env, checkpoint_path, max_iterations=500):
    # Translate the PettingZoo environment to an RLlib environment.
    # Note: RLlib uses a parallelized version of the environment.
    rllib_env = ParallelPettingZooEnv(pettingzoo.utils.conversions.aec_to_parallel(env))
    id_env = "knights_archers_zombies_v10"
    register_env(id_env, lambda config: rllib_env)

    # Fix seeds
    np.random.seed(42)
    torch.manual_seed(42)

    # Define the configuration for the PPO algorithm
    policies = list(env.agents)
    policies_to_train = policies
    config = algo_config(id_env, policies, policies_to_train)

    # Train the model
    algo = config.build()
    for i in range(max_iterations):
        result = algo.train()
        result.pop("config")
        if "env_runners" in result and "agent_episode_returns_mean" in result["env_runners"]:
            print(i, result["env_runners"]["agent_episode_returns_mean"])
            if result["env_runners"]["agent_episode_returns_mean"]["archer_0"] > 5:  # Or any early stopping criterion
                break
        if i % 5 == 0:
            save_result = algo.save(checkpoint_path)
            path_to_checkpoint = save_result.checkpoint.path
            print(
                "An Algorithm checkpoint has been created inside directory: "
                f"'{path_to_checkpoint}'."
            )
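

# Hedged sketch of how the trained policies could be run back in the AEC
# environment after training (illustrative only, not executed by this script):
#
#   predict_fn = CustomPredictFunction(env)
#   env.reset(seed=42)
#   for agent in env.agent_iter():
#       obs, reward, termination, truncation, info = env.last()
#       action = None if termination or truncation else predict_fn(obs, agent)
#       env.step(action)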


if __name__ == "__main__":
    num_agents = 1
    visual_observation = False

    # Create the PettingZoo environment for training
    env = create_environment(num_agents=num_agents, visual_observation=visual_observation)
    env = CustomWrapper(env)

    # Run the training routine
    checkpoint_path = str(Path("results").resolve())
    training(env, checkpoint_path, max_iterations=500)