diff --git a/rlskyjo/environment/skyjo_env.py b/rlskyjo/environment/skyjo_env.py index b85a679..897c9a7 100644 --- a/rlskyjo/environment/skyjo_env.py +++ b/rlskyjo/environment/skyjo_env.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict import numpy as np from gym import spaces @@ -153,32 +153,80 @@ def __init__( # end PettingZoo API stuff def observation_space(self, agent): - """part of the PettingZoo API""" + """ + observations are: + (1,) lowest sum of players, calculated feature + (1,) lowest number of unmasked cards of any player, + calculated feature + (15,) counts of cards past discard pile cards & open player cards, + calculated feature + (1,) top discard pile card + (1,) current hand_card + total: (19,) + + if observe_other_player_indirect is True: + # constant for any num_players + (12) own cards + total: (31,) + elif observe_other_player_indirect is False: + (num_players*4*3,) + total: (19+12*num_players,) + + Args: + agent ([type]): agent string + + Returns: + gym.space: observation_space of agent + """ return self._observation_spaces[agent] def action_space(self, agent): - """part of the PettingZoo API""" + """part of the PettingZoo API + action_space is Discrete(26): + 0-11: place hand card to position 0-11 + 12-23: discard place hand card and reveal position 0-11 + 24: pick hand card from drawpile + 25: pick hand card from discard pile + + Args: + agent ([type]): [description] + + Returns: + [type]: [description] + """ return self._action_spaces[agent] - def observe(self, agent) -> dict: + def observe(self, agent: str) -> Dict[str,np.ndarray]: """ get observation and action mask from environment - part of the PettingZoo API + part of the PettingZoo API] + + Args: + agent ([str]): agent string + + Returns: + dict: {"observations": np.ndarray, "action_mask": np.ndarray} """ + obs, action_mask = self.table.collect_observation( self._name_to_player_id(agent) ) return {"observations": obs, "action_mask": action_mask} def step(self, action: int) -> None: - """ - action is number from 0-25: - 0-11: place hand card to position 0-11 - 12-23: discard place hand card and reveal position 0-11 - 24: pick hand card from drawpile - 25: pick hand card from discard pile - part of the PettingZoo API - """ + """part of the PettingZoo API + + Args: + action (int): + action is number from 0-25: + 0-11: place hand card to position 0-11 + 12-23: discard place hand card and reveal position 0-11 + 24: pick hand card from drawpile + 25: pick hand card from discard pile + + Returns: + None: + """ current_agent = self.agent_selection player_id = self._name_to_player_id(current_agent) @@ -203,7 +251,7 @@ def step(self, action: int) -> None: self._clear_rewards() self._dones_step_first() - def reset(self): + def reset(self) -> None: """ reset the environment part of the PettingZoo API @@ -218,23 +266,26 @@ def reset(self): self.dones = self._convert_to_dict([False for _ in range(self.num_agents)]) self.infos = {i: {} for i in self.agents} - def render(self, mode="human"): - """render board of the game + def render(self, mode="human") -> None: + """render board of the game to stdout part of the PettingZoo API""" if mode == "human": print(self.table.render_table()) - def close(self): + def close(self) -> None: """part of the PettingZoo API""" pass - def seed(self, seed=None): - """ - seed the environment. + def seed(self, seed: int = None) -> None: + """seed the environment. does not affect global np.random.seed() + experimental. only works with Numba installed part of the PettingZoo API - """ + + Args: + seed (int, optional): [description]. Defaults to None. + """ if seed is not None: self.table.set_seed(seed) @@ -245,7 +296,7 @@ def _calc_final_rewards( """ get reward from score. reward is relative performance to average score - mean reward is 1 + default mean reward is self.mean_reward == 1 args: game_results: dict['str': np.array of len(players) e.g. np.array([35,65,50]) @@ -262,7 +313,14 @@ def _calc_final_rewards( @staticmethod def _name_to_player_id(name: str) -> int: - """convert agent name to int e.g. player_1 to int(1)""" + """[convert agent name to int e.g. player_1 to int(1)] + + Args: + name (str): agent name + + Returns: + int: agent int + """ return int(name.split("_")[-1]) def _convert_to_dict(self, list_of_list): diff --git a/rlskyjo/models/random_admissible_policy.py b/rlskyjo/models/random_admissible_policy.py index 2680aef..ce8a5a5 100644 --- a/rlskyjo/models/random_admissible_policy.py +++ b/rlskyjo/models/random_admissible_policy.py @@ -9,7 +9,16 @@ def policy_ra( rng: Union[None, np.random.Generator] = None, ) -> int: """for demonstration. - picks randomly an admissible action from the action mask""" + picks randomly an admissible action from the action mask + + Args: + observation (np.array): [description] + action_mask (np.array): [description] + rng (Union[None, np.random.Generator], optional): [description]. Defaults to None. + + Returns: + int: [description] + """ if rng is None: module = np.random else: diff --git a/rlskyjo/utils.py b/rlskyjo/utils.py index ddc14b5..19aedbd 100644 --- a/rlskyjo/utils.py +++ b/rlskyjo/utils.py @@ -1,10 +1,12 @@ from pathlib import Path -def get_project_root() -> Path: - """ - return Path to the project directory, top folder of rlskyjo - """ +def get_project_root() -> Path: + """return Path to the project directory, top folder of rlskyjo + + Returns: + Path: Path to the project directory + """ return Path(__file__).parent.parent.resolve()