forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path gym_transforms.py
200 lines (176 loc) · 9.59 KB
/
gym_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Gym-specific transforms."""
import warnings
import torch
import torchrl.objectives.common
from tensordict import TensorDictBase
from tensordict.utils import expand_as_right, NestedKey
from torchrl.data.tensor_specs import UnboundedDiscreteTensorSpec
from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform
class EndOfLifeTransform(Transform):
    """Registers the end-of-life signal from a Gym env with a `lives` method.

    Proposed by DeepMind for the DQN and co. It helps value estimation.

    Args:
        eol_key (NestedKey, optional): the key where the end-of-life signal should
            be written. Defaults to ``"end-of-life"``.
        lives_key (NestedKey, optional): the key where the current number of lives
            is written. The value stored under this key at the previous step is
            compared with the new count to detect a life loss.
            Defaults to ``"lives"``.
        done_key (NestedKey, optional): a "done" key in the parent env done_spec,
            where the done value can be retrieved. This key must be unique and its
            shape must match the shape of the end-of-life entry. Defaults to ``"done"``.
        eol_attribute (str, optional): the location of the "lives" in the gym env.
            Defaults to ``"unwrapped.ale.lives"``. Supported attribute types are
            integer/array-like objects or callables that return these values.

    .. note::
        This transform should be used with gym envs that have a ``env.unwrapped.ale.lives``.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.envs.transforms.transforms import TransformedEnv
        >>> env = GymEnv("ALE/Breakout-v5")
        >>> env.rollout(100)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=cpu,
                    is_shared=False),
                pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)
        >>> eol_transform = EndOfLifeTransform()
        >>> env = TransformedEnv(env, eol_transform)
        >>> env.rollout(100)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                eol: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        end-of-life: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
                        pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=cpu,
                    is_shared=False),
                pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)

    The typical usage of this transform is to replace the "done" state by "end-of-life"
    within the loss module. The end-of-life signal isn't registered within the ``done_spec``
    because it should not instruct the env to reset.

    Examples:
        >>> from torchrl.objectives import DQNLoss
        >>> module = torch.nn.Identity() # used as a placeholder
        >>> loss = DQNLoss(module, action_space="categorical")
        >>> loss.set_keys(done="end-of-life", terminated="end-of-life")
        >>> # equivalently
        >>> eol_transform.register_keys(loss)
    """

    NO_PARENT_ERR = "The {} transform is being executed without a parent env. This is currently not supported."

    def __init__(
        self,
        eol_key: NestedKey = "end-of-life",
        lives_key: NestedKey = "lives",
        done_key: NestedKey = "done",
        eol_attribute="unwrapped.ale.lives",
    ):
        super().__init__(in_keys=[done_key], out_keys=[eol_key, lives_key])
        self.eol_key = eol_key
        self.lives_key = lives_key
        self.done_key = done_key
        # Kept as a list of attribute names, resolved one at a time in ``_get_lives``.
        self.eol_attribute = eol_attribute.split(".")

    def _get_lives(self):
        """Read the current number of lives from the parent's base env.

        Walks the attribute path in ``self.eol_attribute`` starting at
        ``self.parent.base_env``. If an intermediate value is a list, the remaining
        attributes are resolved element-wise. A callable result is called; a list of
        callables is called element-wise and the results are packed in a tensor.

        Returns:
            whatever the attribute path yields after the optional call: a scalar,
            an array-like, or a tensor built from a list of per-env results.
        """
        from torchrl.envs.libs.gym import GymWrapper

        base_env = self.parent.base_env
        if not isinstance(base_env, GymWrapper):
            warnings.warn(
                f"The base_env is not a gym env. Compatibility of {type(self)} is not guaranteed with "
                f"environment types that do not inherit from GymWrapper.",
                category=UserWarning,
            )
        # getattr falls back on _env by default
        lives = getattr(base_env, self.eol_attribute[0])
        for att in self.eol_attribute[1:]:
            if isinstance(lives, list):
                # For SerialEnv (and who knows Parallel one day)
                lives = [getattr(_lives, att) for _lives in lives]
            else:
                lives = getattr(lives, att)
        if callable(lives):
            lives = lives()
        elif isinstance(lives, list) and all(callable(_lives) for _lives in lives):
            lives = torch.tensor([_lives() for _lives in lives])
        return lives

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Pass-through: this transform only acts in ``_step`` and ``reset``."""
        return tensordict

    def _step(self, tensordict, next_tensordict):
        """Write the end-of-life flag and the new lives count into ``next_tensordict``.

        A life is considered lost when the lives count stored in ``tensordict``
        (written at the previous step or reset) is greater than the current count.
        The flag is OR'ed with the ``done_key`` entry, so a done step is also
        flagged as end-of-life.

        Raises:
            RuntimeError: if the transform has no parent env.
            KeyError: if ``done_key`` cannot be found in ``next_tensordict``.
        """
        parent = self.parent
        if parent is None:
            raise RuntimeError(self.NO_PARENT_ERR.format(type(self)))

        lives = self._get_lives()
        # ``as_tensor`` (rather than ``torch.tensor``) avoids PyTorch's
        # copy-construct warning when the comparison already produced a tensor.
        end_of_life = torch.as_tensor(
            tensordict.get(self.lives_key) < lives, device=parent.device
        )
        try:
            done = next_tensordict.get(self.done_key)
        except KeyError as err:
            # Report the keys of the tensordict that was actually searched.
            raise KeyError(
                f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {next_tensordict.keys(True, True)}. "
                f"Make sure to pass the appropriate done_key to the {type(self)} transform."
            ) from err
        end_of_life = expand_as_right(end_of_life, done) | done
        next_tensordict.set(self.eol_key, end_of_life)
        next_tensordict.set(self.lives_key, lives)
        return next_tensordict

    def reset(self, tensordict):
        """Initialise the end-of-life flag (``False``) and the lives count on reset.

        The flag is expanded to the shape of the parent's ``done_key`` spec.

        Raises:
            RuntimeError: if the transform has no parent env.
        """
        parent = self.parent
        if parent is None:
            raise RuntimeError(self.NO_PARENT_ERR.format(type(self)))

        lives = self._get_lives()
        # Created on the parent's device, consistently with ``_step``.
        end_of_life = torch.tensor(False, device=parent.device).expand(
            parent.full_done_spec[self.done_key].shape
        )
        tensordict.set(self.eol_key, end_of_life)
        tensordict.set(self.lives_key, lives)
        return tensordict

    def transform_observation_spec(self, observation_spec):
        """Add the end-of-life and lives entries to the observation spec.

        The end-of-life spec is a clone of the parent's ``done_key`` spec; the lives
        spec is an unbounded int64 discrete spec shaped like the parent's batch size.
        """
        full_done_spec = self.parent.output_spec["full_done_spec"]
        observation_spec[self.eol_key] = full_done_spec[self.done_key].clone()
        observation_spec[self.lives_key] = UnboundedDiscreteTensorSpec(
            self.parent.batch_size,
            device=self.parent.device,
            dtype=torch.int64,
        )
        return observation_spec

    def register_keys(self, loss_or_advantage: "torchrl.objectives.common.LossModule"):
        """Registers the end-of-life key at appropriate places within the loss.

        Args:
            loss_or_advantage (torchrl.objectives.LossModule or torchrl.objectives.value.ValueEstimatorBase): a module to instruct what the end-of-life key is.
        """
        loss_or_advantage.set_keys(done=self.eol_key, terminated=self.eol_key)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Not supported: always raises ``RuntimeError``."""
        raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))