diff --git a/brax/envs/inverted_double_pendulum.py b/brax/envs/inverted_double_pendulum.py index b7dd2b603..d296a84d1 100644 --- a/brax/envs/inverted_double_pendulum.py +++ b/brax/envs/inverted_double_pendulum.py @@ -62,7 +62,7 @@ class InvertedDoublePendulum(PipelineEnv): pendulum system, followed by the velocities of those individual parts (their derivatives) with all the positions ordered before all the velocities. - The observation is a `ndarray` with shape `(11,)` where the elements + The observation is a `ndarray` with shape `(8,)` where the elements correspond to the following: | Num | Observation | @@ -96,9 +96,19 @@ class InvertedDoublePendulum(PipelineEnv): ### Rewards - The goal is to make the inverted pendulum stand upright (within a certain - angle limit) as long as possible - as such a reward of +1 is awarded for each - timestep that the pole is upright. + The total reward is: ***reward*** *=* *alive_bonus - distance_penalty - velocity_penalty*. + + - *alive_bonus*: + Every timestep that the Inverted Pendulum is healthy (see definition in section "Episode Termination"), + it gets a reward of fixed value `healthy_reward` (default is $10$). + - *distance_penalty*: + This reward is a measure of how far the *tip* of the second pendulum (the only free end) moves, + and it is calculated as $0.01 x_{pole2-tip}^2 + (y_{pole2-tip}-2)^2$, + where $x_{pole2-tip}, y_{pole2-tip}$ are the xy-coordinatesof the tip of the second pole. + - *velocity_penalty*: + A negative reward to penalize the agent for moving too fast. + $10^{-3} \omega_1 + 5 \times 10^{-3} \omega_2$, + where $\omega_1, \omega_2$ are the angular velocities of the hinges. ### Starting State @@ -107,11 +117,11 @@ class InvertedDoublePendulum(PipelineEnv): ### Episode Termination - The episode terminates when any of the following happens: - - 1. The episode duration reaches 1000 timesteps. - 2. The absolute value of the vertical angle between the pole and the cart is - greater than 0.2 radians. + The episode terminates when the y_coordinate of the tip of the second + pole $\leq 1$. + + Note: The maximum standing height of the system is 1.2 m when all the parts + are perpendicularly vertical on top of each other. """ # pyformat: enable @@ -152,18 +162,19 @@ def step(self, state: State, action: jax.Array) -> State: """Run one timestep of the environment's dynamics.""" pipeline_state = self.pipeline_step(state.pipeline_state, action) - tip = base.Transform.create(pos=jp.array([0.0, 0.0, 0.6])).do( - pipeline_state.x.take(2) + tip = pipeline_state.x.take(2).do( + base.Transform.create(pos=jp.array([0.0, 0.0, 0.6])) ) x, _, y = tip.pos - dist_penalty = 0.01 * x**2 + (y - 2) ** 2 v1, v2 = pipeline_state.qd[1:] + + dist_penalty = 0.01 * x**2 + (y - 2) ** 2 vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2 alive_bonus = 10 obs = self._get_obs(pipeline_state) - reward = alive_bonus - dist_penalty - vel_penalty done = jp.where(y <= 1, jp.float32(1), jp.float32(0)) + reward = (1 - done) * alive_bonus - dist_penalty - vel_penalty return state.replace( pipeline_state=pipeline_state, obs=obs, reward=reward, done=done