-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy patharm_push_easy.py
124 lines (103 loc) · 4.81 KB
/
arm_push_easy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import jax
from brax import base
from jax import numpy as jnp
from envs.manipulation.arm_envs import ArmEnvs
"""
Push-Easy: Move a cube from a random location on the blue region to a random goal on the adjacent red region. The regions are very small.
- Observation space: 18-dim obs + 3-dim goal.
- Action space: 5-dim, each element in [-1, 1], corresponding to target angles for joints 1, 2, 4, 6, and finger closedness.
See _get_obs() and ArmEnvs._convert_action() for details.
"""
class ArmPushEasy(ArmEnvs):
    """Push-Easy environment: slide a cube from a random spot on the blue
    region to a random goal on the adjacent red region (both regions small).

    Observation: 18-dim state + 3-dim goal (cube position).
    Action: 5-dim in [-1, 1] — target angles for joints 1, 2, 4, 6 plus
    finger closedness; see ArmEnvs._convert_action().
    """

    def _get_xml_path(self):
        return "envs/assets/panda_push_easy.xml"

    @property
    def action_size(self) -> int:
        # Overrides the default, which would be the raw actuator count.
        return 5

    def _set_environment_attributes(self):
        # Attribute meanings are documented in ArmEnvs._set_environment_attributes.
        self.env_name = "arm_push_easy"
        self.episode_length = 150
        self.goal_indices = jnp.array([0, 1, 2])  # Cube position in obs
        self.completion_goal_indices = jnp.array([0, 1, 2])  # Same subset
        self.state_dim = 18
        self.goal_reach_thresh = 0.1
        self.arm_noise_scale = 0
        self.cube_noise_scale = 0.1
        self.goal_noise_scale = 0.1

    def _get_initial_state(self, rng):
        """Sample initial (q, qd): jittered cube x/y, fixed target, near-task arm pose."""
        rng, cube_key, arm_key = jax.random.split(rng, 3)

        # Only the cube's planar position is randomized; its height/orientation
        # and the goal marker's full pose come straight from init_q.
        cube_xy = self.sys.init_q[:2] + self.cube_noise_scale * jax.random.uniform(
            cube_key, [2], minval=-1
        )
        cube_rest = self.sys.init_q[2:7]
        target_q = self.sys.init_q[7:14]

        # Default arm pose starts closer to the relevant area.
        default_arm = jnp.array(
            [1.571, 0.742, 0, -1.571, 0, 3.054, 1.449, 0.04, 0.04]
        )
        arm_q = default_arm + self.arm_noise_scale * jax.random.uniform(
            arm_key, [self.sys.q_size() - 14], minval=-1
        )

        q = jnp.concatenate([cube_xy, cube_rest, target_q, arm_q])
        qd = jnp.zeros([self.sys.qd_size()])
        return q, qd

    def _get_initial_goal(self, pipeline_state: base.State, rng):
        """Sample a goal cube position: fixed anchor plus x/y noise (z fixed)."""
        rng, goal_key = jax.random.split(rng)
        xy_noise = jnp.array(
            [self.goal_noise_scale, self.goal_noise_scale, 0]
        ) * jax.random.uniform(goal_key, [3], minval=-1)
        return jnp.array([0.1, 0.6, 0.03]) + xy_noise

    def _compute_goal_completion(self, obs, goal):
        """Return (success, success_easy, success_hard) based on cube-goal distance."""
        cube_pos = obs[self.completion_goal_indices]
        target_pos = goal[:3]
        separation = jnp.linalg.norm(cube_pos - target_pos)
        # Three thresholds: nominal, lenient (0.3), and strict (0.03).
        success = jnp.array(separation < self.goal_reach_thresh, dtype=float)
        success_easy = jnp.array(separation < 0.3, dtype=float)
        success_hard = jnp.array(separation < 0.03, dtype=float)
        return success, success_easy, success_hard

    def _update_goal_visualization(
        self, pipeline_state: base.State, goal: jax.Array
    ) -> base.State:
        """Move the goal marker body (q[7:10]) to the goal position; orientation untouched."""
        new_q = pipeline_state.q.at[7:10].set(goal[:3])
        return pipeline_state.replace(qpos=new_q)

    def _get_obs(
        self, pipeline_state: base.State, goal: jax.Array, timestep
    ) -> jax.Array:
        """Assemble the observation vector.

        Observation space (18-dim)
          - q_subset (10-dim): 3-dim cube position, 7-dim joint angles
          - End-effector (6-dim): position and velocity
          - Fingers (2-dim): finger distance, gripper force
        Note q is 23-dim: 7-dim cube position/angle, 7-dim goal marker
        position/angle, 7-dim joint angles, 2-dim finger offset.
        Goal space (3-dim): position of cube.
        """
        # Cube position (q[0:3]) plus the seven arm joint angles (q[14:21]).
        q_subset = pipeline_state.q[jnp.array([0, 1, 2, 14, 15, 16, 17, 18, 19, 20])]

        # Body index 8 is link 7, merged with the end-effector base:
        # cube = 0, goal marker = 1, links 1-7 = indices 2-8.
        eef_idx = 8
        eef_pos = pipeline_state.x.pos[eef_idx]
        eef_vel = pipeline_state.xd.vel[eef_idx]

        left_pos = pipeline_state.x.pos[9]
        right_pos = pipeline_state.x.pos[10]
        # [None] promotes the scalar norm to a 1-element vector for concatenation.
        finger_distance = jnp.linalg.norm(right_pos - left_pos)[None]

        # Scale from roughly [-20, 20] down to [-2, 2].
        # NOTE(review): [:-2] averages all actuator forces *except* the last
        # two — if the fingers are the final two actuators this excludes
        # them, which seems at odds with the "gripper force" label; verify
        # against the MJCF actuator ordering.
        gripper_force = (pipeline_state.qfrc_actuator[:-2]).mean(keepdims=True) * 0.1

        return jnp.concatenate(
            [q_subset, eef_pos, eef_vel, finger_distance, gripper_force, goal]
        )

    def _get_arm_angles(self, pipeline_state: base.State) -> jax.Array:
        """Return the seven arm joint angles (q[14:21])."""
        return pipeline_state.q[14:21]