Skip to content

Commit

Permalink
first non-erroring example of sb1
Browse files Browse the repository at this point in the history
  • Loading branch information
rallen10 committed Feb 25, 2023
1 parent 3e7eec6 commit 6349298
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 28 deletions.
12 changes: 6 additions & 6 deletions scripts/sb_objective_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,24 @@
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def sb_objective(r, t, r_d, A=1, B=1):
def sb_objective(r, t, r_d=50, A=1, B=0.001):
return - A * np.exp(-B * (r-r_d)**2) * np.cos(t)

if __name__ == "__main__":

# setup range and angle arrays
r = np.linspace(0, 10 ,100)
r = np.linspace(0, 100 ,100)
t = np.linspace(0, 2*np.pi, 100)
R, T = np.meshgrid(r, t)

# evaluate function on meshgrid
Z = sb_objective(R, T, 2)
Z = sb_objective(R, T)

# Create a 3D plot of the function
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(R*np.cos(T), R*np.sin(T), Z, cmap='viridis')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_xlabel('x-pos [m]')
ax.set_ylabel('y-pos [m]')
ax.set_zlabel('reward [-]')
plt.show()
12 changes: 6 additions & 6 deletions src/kspdg/lbg1/lbg1_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,12 @@ def get_observation(self) -> ArrayLike:
[0] : mission elapsed time [s]
[1] : current vehicle (bandit) mass [kg]
[2] : current vehicle (bandit) propellant (mono prop) [kg]
[3:6] : bandit position in reference orbit right-hand CBCI coords [m]
[6:9] : bandit velocity in reference orbit right-hand CBCI coords [m/s]
[9:12] : lady position in reference orbit right-hand CBCI coords [m]
[12:15] : lady velocity in reference orbit right-hand CBCI coords [m/s]
[15:18] : guard position in reference orbit right-hand CBCI coords [m]
[18:21] : guard velocity in reference orbit right-hand CBCI coords [m/s]
[3:6] : bandit position wrt CB in right-hand CBCI coords [m]
[6:9] : bandit velocity wrt CB in right-hand CBCI coords [m/s]
[9:12] : lady position wrt CB in right-hand CBCI coords [m]
[12:15] : lady velocity wrt CB in right-hand CBCI coords [m/s]
[15:18] : guard position wrt CB in right-hand CBCI coords [m]
[18:21] : guard velocity wrt CB in right-hand CBCI coords [m/s]
Ref:
- CBCI stands for celestial-body-centered inertial, which is a corollary to ECI coords
Expand Down
11 changes: 6 additions & 5 deletions src/kspdg/pe1/pe1_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def __init__(self, loadfile:str,
super().__init__(**kwargs)

assert episode_timeout > 0
assert capture_dist > 0 or capture_dist is None
if capture_dist is not None:
assert capture_dist > 0
self.episode_timeout = episode_timeout
self.capture_dist = capture_dist

Expand Down Expand Up @@ -355,10 +356,10 @@ def get_observation(self):
[0] : mission elapsed time [s]
[1] : current vehicle (pursuer) mass [kg]
[2] : current vehicle (pursuer) propellant (mono prop) [kg]
[3:6] : pursuer position in reference orbit right-hand CBCI coords [m]
[6:9] : pursuer velocity in reference orbit right-hand CBCI coords [m/s]
[9:12] : evader position in reference orbit right-hand CBCI coords [m]
[12:15] : evader velocity in reference orbit right-hand CBCI coords [m/s]
[3:6] : pursuer position wrt CB in right-hand CBCI coords [m]
[6:9] : pursuer velocity wrt CB in right-hand CBCI coords [m/s]
[9:12] : evader position wrt CB in right-hand CBCI coords [m]
[12:15] : evader velocity wrt CB in right-hand CBCI coords [m/s]
Ref:
- CBCI stands for celestial-body-centered inertial, which is a corollary to ECI coords
Expand Down
30 changes: 30 additions & 0 deletions src/kspdg/sb1/e1_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from kspdg.sb1.sb1_base import SunBlockingGroup1Env

class SB1_E1_ParentEnv(SunBlockingGroup1Env):
    """Sun-blocking group 1 environment whose evader stays passive.

    All construction arguments are handed to ``SunBlockingGroup1Env``
    unchanged; the only behavior defined here is ``evasive_maneuvers``,
    which is a no-op.
    """

    def __init__(self, loadfile: str, **kwargs):
        """Create the environment from the given load file.

        Args:
            loadfile : str
                load file identifier forwarded to the base environment
            **kwargs
                remaining keyword arguments forwarded to the base environment
        """
        super().__init__(loadfile=loadfile, **kwargs)

    def evasive_maneuvers(self):
        """Do not perform evasive maneuvers."""
        # intentionally empty: nothing is commanded here
        return None

class SB1_E1_I1_Env(SB1_E1_ParentEnv):
    """Passive-evader sun-blocking environment bound to ``LOADFILE_I1``."""

    def __init__(self, **kwargs):
        # the load file is fixed by this class; every other keyword
        # argument is passed through to the parent environment untouched
        scenario_file = SunBlockingGroup1Env.LOADFILE_I1
        super().__init__(loadfile=scenario_file, **kwargs)

class SB1_E1_I2_Env(SB1_E1_ParentEnv):
    """Passive-evader sun-blocking environment bound to ``LOADFILE_I2``."""

    def __init__(self, **kwargs):
        # the load file is fixed by this class; every other keyword
        # argument is passed through to the parent environment untouched
        scenario_file = SunBlockingGroup1Env.LOADFILE_I2
        super().__init__(loadfile=scenario_file, **kwargs)

class SB1_E1_I3_Env(SB1_E1_ParentEnv):
    """Passive-evader sun-blocking environment bound to ``LOADFILE_I3``."""

    def __init__(self, **kwargs):
        # the load file is fixed by this class; every other keyword
        # argument is passed through to the parent environment untouched
        scenario_file = SunBlockingGroup1Env.LOADFILE_I3
        super().__init__(loadfile=scenario_file, **kwargs)

class SB1_E1_I4_Env(SB1_E1_ParentEnv):
    """Passive-evader sun-blocking environment bound to ``LOADFILE_I4``."""

    def __init__(self, **kwargs):
        # the load file is fixed by this class; every other keyword
        # argument is passed through to the parent environment untouched
        scenario_file = SunBlockingGroup1Env.LOADFILE_I4
        super().__init__(loadfile=scenario_file, **kwargs)
77 changes: 66 additions & 11 deletions src/kspdg/sb1/sb1_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

import gymnasium as gym
import numpy as np

from typing import List, Dict

import kspdg.utils.utils as U
from kspdg.pe1.pe1_base import PursuitEvadeGroup1Env

DEFAULT_EPISODE_TIMEOUT = 240.0 # [sec]
DEFAULT_TARGET_VIEWING_DISTANCE = 50.0 # [m]
DEFAULT_REWARD_DECAY_COEF = 0.001 # [1/m^2]

class SunBlockingGroup1Env(PursuitEvadeGroup1Env):
'''
Expand All @@ -20,6 +28,8 @@ class SunBlockingGroup1Env(PursuitEvadeGroup1Env):

def __init__(self, loadfile:str,
episode_timeout:float = DEFAULT_EPISODE_TIMEOUT,
target_viewing_distance:float = DEFAULT_TARGET_VIEWING_DISTANCE,
reward_decay_coef:float = DEFAULT_REWARD_DECAY_COEF,
**kwargs):
"""
Args:
Expand All @@ -36,10 +46,18 @@ def __init__(self, loadfile:str,
**kwargs
)

assert target_viewing_distance > 0.0
assert reward_decay_coef > 0.0
self.target_viewing_distance = target_viewing_distance
self.reward_decay_coef = reward_decay_coef

# overwrite the pe1 observation space
# to include sun position information
# (see get_observation for mapping)
self.observation_space = NotImplemented
self.observation_space = gym.spaces.Box(
low = np.concatenate((np.zeros(3), -np.inf*np.ones(15))),
high = np.inf * np.ones(18)
)

# don't call reset. This allows instantiation and partial testing
# without connecting to krpc server
Expand All @@ -48,7 +66,8 @@ def _reset_episode_metrics(self) -> None:
""" Reset attributes that track proximity, timing, and propellant use metrics
"""

raise NotImplementedError()
# TODO: customize this for sun-blocking beyond pure pursuit-evade
super()._reset_episode_metrics()

def get_reward(self) -> float:
""" Compute reward value
Expand All @@ -59,7 +78,24 @@ def get_reward(self) -> float:
rew : float
reward at current step
"""
raise NotImplementedError

# get evader position, distance, and unit vector relative to pursuer
p_vesE_vesP__lhpbody = self.vesEvade.position(self.vesPursue.reference_frame)
d_vesE_vesP = np.linalg.norm(p_vesE_vesP__lhpbody)
u_vesE_vesP__lhpbody = p_vesE_vesP__lhpbody/d_vesE_vesP

# get sun unit vector relative to pursuer
p_sun_vesP__lhpbody = self.conn.space_center.bodies['Sun'].position(
self.vesPursue.reference_frame)
d_sun_vesP = np.linalg.norm(p_sun_vesP__lhpbody)
u_sun_vesP__lhpbody = p_sun_vesP__lhpbody/d_sun_vesP

# compute reward. See sb_objective_plot.py for intuition
# about reward surface shape
rew = -np.dot(u_vesE_vesP__lhpbody, u_sun_vesP__lhpbody)
rew *= np.exp(-self.reward_decay_coef * (d_vesE_vesP - self.target_viewing_distance)**2)

return rew

def get_info(self, observation: List, done: bool) -> Dict:
"""compute performance metrics
Expand All @@ -70,27 +106,46 @@ def get_info(self, observation: List, done: bool) -> Dict:
True if last step of episode
"""

raise NotImplementedError
# TODO: customize this for sun-blocking beyond pure pursuit-evade
return super().get_info(observation=observation, done=done)

def get_observation(self):
''' return observation of pursuit and evader vessels from referee ref frame
""" return observation of pursuit and evader vessels from referee ref frame
Returns:
obs : list
[0] : mission elapsed time [s]
[1] : current vehicle (pursuer) mass [kg]
[2] : current vehicle (pursuer) propellant (mono prop) [kg]
[3:6] : pursuer position in reference orbit right-hand CBCI coords [m]
[6:9] : pursuer velocity in reference orbit right-hand CBCI coords [m/s]
[9:12] : evader position in reference orbit right-hand CBCI coords [m]
[12:15] : evader velocity in reference orbit right-hand CBCI coords [m/s]
[3:6] : pursuer position wrt CB in right-hand CBCI coords [m]
[6:9] : pursuer velocity wrt CB in right-hand CBCI coords [m/s]
[9:12] : evader position wrt CB in right-hand CBCI coords [m]
[12:15] : evader velocity wrt CB in right-hand CBCI coords [m/s]
[15:18] : pursuer's sun-pointing unit vector in right-hand CBCI coords [-]
Ref:
- CBCI stands for celestial-body-centered inertial, which is a corollary to ECI coords
(see notation: https://github.com/mit-ll/spacegym-kspdg#code-notation)
- KSP's body-centered inertial reference frame is left-handed
(see https://krpc.github.io/krpc/python/api/space-center/celestial-body.html#SpaceCenter.CelestialBody.non_rotating_reference_frame)
'''
"""

obs = super().get_observation()

# get sun position relative to the celestial body the pursuer orbits (left-handed CBCI frame)
p_sun_vesP__lhcbci = np.array(self.conn.space_center.bodies['Sun'].position(
self.vesPursue.orbit.body.non_rotating_reference_frame))

# convert to right-handed CBCI coords
p_sun_vesP__rhcbci = U.convert_lhcbci_to_rhcbci(p_sun_vesP__lhcbci)
d_sun_vesP = np.linalg.norm(p_sun_vesP__rhcbci)
u_sun_vesP__rhcbci = p_sun_vesP__rhcbci/d_sun_vesP

# encode into observation
obs.extend(u_sun_vesP__rhcbci)

return obs



raise NotImplementedError

0 comments on commit 6349298

Please sign in to comment.