-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdmc.py
106 lines (90 loc) · 3.45 KB
/
dmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import glob
import gym
import numpy as np
import os
from gym import spaces
import local_dm_control_suite as suite
from .img_sources import make_img_source
# https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/wrappers.py
# https://github.com/facebookresearch/deep_bisim4control/blob/main/dmc2gym/wrappers.py
def _flatten_obs(obs):
    """Join every value of an observation mapping into one 1-D array.

    Values that ``np.isscalar`` recognises are wrapped into length-1 arrays;
    every other value is flattened through its own ``ravel()`` method. The
    pieces are concatenated in the mapping's iteration order.

    Args:
        obs: Mapping whose values are scalars or objects exposing ``ravel()``
            (e.g. NumPy arrays).

    Returns:
        A one-dimensional NumPy array holding all entries back to back.
    """
    return np.concatenate(
        [
            np.array([value]) if np.isscalar(value) else value.ravel()
            for value in obs.values()
        ],
        axis=0,
    )
class DMCEnv(gym.Env):
    """Gym-style environment around a task loaded from ``local_dm_control_suite``.

    Observations are either rendered images in channel-first layout
    (``pixel_obs=True``) or the task's observation mapping flattened into a
    single vector. When an image source is configured, the pixels selected by
    a hard-coded colour mask are overwritten with pixels from that source.
    """

    def __init__(
        self, name, pixel_obs, img_source, resource_files, total_frames, reset_bg
    ):
        """Load the task and build the observation and action spaces.

        Args:
            name: Task identifier of the form ``"<domain>-<task>"``; it is
                split on the first ``-`` only, so the task part may itself
                contain dashes.
            pixel_obs: If true, observations are rendered images; otherwise
                they are flattened state vectors.
            img_source: Source type forwarded to ``make_img_source`` as
                ``src_type``. ``None`` disables background replacement.
            resource_files: Forwarded unchanged to ``make_img_source``.
            total_frames: Forwarded unchanged to ``make_img_source``.
            reset_bg: If true (and an image source is set), the background
                source is reset at the start of every ``reset()``.
        """
        domain, task = name.split("-", 1)
        self._env = suite.load(domain, task)
        self._pixel_obs = pixel_obs
        self._img_source = img_source
        self._reset_bg = reset_bg
        self._resolution = 64
        self._camera_id = 0
        # Always define the attribute so it exists even without an image source.
        self._bg_source = None

        if pixel_obs:
            # Channel-first uint8 image, matching the transpose in _get_obs.
            self.observation_space = spaces.Box(
                low=0,
                high=255,
                shape=(3, self._resolution, self._resolution),
                dtype=np.uint8,
            )
        else:
            # One slot per element of every entry in the observation spec.
            obs_spec = self._env.observation_spec()
            obs_len = sum(int(np.prod(s.shape)) for s in obs_spec.values())
            self.observation_space = spaces.Box(
                low=-np.inf, high=np.inf, shape=(obs_len,), dtype=np.float32
            )

        act_spec = self._env.action_spec()
        self.action_space = spaces.Box(
            low=act_spec.minimum.astype(np.float32),
            high=act_spec.maximum.astype(np.float32),
            dtype=np.float32,
        )

        # Change background
        if img_source is not None:
            self._bg_source = make_img_source(
                src_type=img_source,
                img_shape=(self._resolution, self._resolution),
                resource_files=resource_files,
                total_frames=total_frames,
                grayscale=True,
            )

    def seed(self, seed):
        """Seed the observation and action spaces.

        Only the two Gym spaces are seeded here; this method does not touch
        the wrapped environment.
        """
        self.observation_space.seed(seed)
        self.action_space.seed(seed)

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action passed straight to the wrapped environment.

        Returns:
            A ``(obs, reward, done, info)`` tuple. ``done`` is always
            ``False``; ``info`` carries the time step's ``discount``.
        """
        time_step = self._env.step(action)
        obs = self._get_obs(time_step)
        # dm_env time steps carry reward=None on a FIRST step. Since this
        # wrapper never signals done, stepping past an episode end can
        # presumably produce such a step (TODO confirm for the local suite);
        # report a numeric zero instead of leaking None to the caller.
        reward = 0.0 if time_step.reward is None else time_step.reward
        info = {"discount": time_step.discount}
        return obs, reward, False, info

    def reset(self):
        """Reset the wrapped environment and return the first observation.

        The background source is reset first when one is configured and
        ``reset_bg`` was requested.
        """
        if self._img_source is not None and self._reset_bg:
            self._bg_source.reset()
        time_step = self._env.reset()
        return self._get_obs(time_step)

    def render(self, mode="rgb_array", height=64, width=64, camera_id=0):
        """Render the current physics state through the given camera.

        Args:
            mode: Must be ``"rgb_array"``; any other value fails the assert.
            height: Height passed to the physics renderer.
            width: Width passed to the physics renderer.
            camera_id: Camera passed to the physics renderer.

        Returns:
            Whatever ``physics.render`` returns (presumably an
            ``(height, width, 3)`` uint8 array; verify against the suite).
        """
        assert mode == "rgb_array", "DMC only supports rgb_array render mode"
        return self._env.physics.render(height=height, width=width, camera_id=camera_id)

    def _get_obs(self, time_step):
        """Build the observation for ``time_step``.

        In pixel mode the scene is rendered at the configured resolution and
        camera, optionally composited with the background source, and returned
        channel-first. Otherwise the time step's observation mapping is
        flattened into a single vector.
        """
        if not self._pixel_obs:
            return _flatten_obs(time_step.observation)

        obs = self.render(
            mode="rgb_array",
            height=self._resolution,
            width=self._resolution,
            camera_id=self._camera_id,
        )
        if self._img_source is not None:
            # Hardcoded mask for dmc: select pixels whose channel 2 is strictly
            # greater than both channel 0 and channel 1 (presumably the
            # blue-dominant default backdrop; confirm against rendered frames).
            mask = np.logical_and(
                (obs[:, :, 2] > obs[:, :, 1]), (obs[:, :, 2] > obs[:, :, 0])
            )
            bg = self._bg_source.get_image()
            obs[mask] = bg[mask]
        # Move the channel axis to the front; copy to get a contiguous array.
        return obs.transpose(2, 0, 1).copy()