Skip to content

Commit

Permalink
Don't pass env into _EnvPlayer instantiation (hsahovic#672)
Browse files Browse the repository at this point in the history
* initial commit

* fix tests

* fix test

* unused import

* polish

* settle down

* condense

* action_to_order

* black

---------

Co-authored-by: Haris Sahovic <[email protected]>
  • Loading branch information
cameronangliss and hsahovic authored Jan 25, 2025
1 parent d6b0474 commit 8bb067e
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 136 deletions.
2 changes: 1 addition & 1 deletion examples/gymnasium_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def describe_embedding(self):
dtype=np.float64,
)

def action_to_move(self, action, battle):
def action_to_order(self, action, battle):
return self.agent.choose_random_move(battle)

def calc_reward(self, battle):
Expand Down
22 changes: 12 additions & 10 deletions src/poke_env/player/env_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import List, Optional, Union
from weakref import WeakKeyDictionary

import numpy as np

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player.battle_order import BattleOrder, ForfeitBattleOrder
from poke_env.player.gymnasium_api import GymnasiumEnv
Expand Down Expand Up @@ -224,7 +226,7 @@ class Gen4EnvSinglePlayer(EnvPlayer):
_ACTION_SPACE = list(range(4 + 6))
_DEFAULT_BATTLE_FORMAT = "gen4randombattle"

def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
def action_to_order(self, action: np.int64, battle: AbstractBattle) -> BattleOrder:
"""Converts actions to move orders.
The conversion is done as follows:
Expand All @@ -239,7 +241,7 @@ def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
If the proposed action is illegal, a random legal move is performed.
:param action: The action to convert.
:type action: int
:type action: np.int64
:param battle: The battle in which to act.
:type battle: Battle
:return: the order to send to the server.
Expand Down Expand Up @@ -267,7 +269,7 @@ class Gen6EnvSinglePlayer(EnvPlayer):
_ACTION_SPACE = list(range(2 * 4 + 6))
_DEFAULT_BATTLE_FORMAT = "gen6randombattle"

def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
def action_to_order(self, action: np.int64, battle: AbstractBattle) -> BattleOrder:
"""Converts actions to move orders.
The conversion is done as follows:
Expand All @@ -285,7 +287,7 @@ def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
If the proposed action is illegal, a random legal move is performed.
:param action: The action to convert.
:type action: int
:type action: np.int64
:param battle: The battle in which to act.
:type battle: Battle
:return: the order to send to the server.
Expand Down Expand Up @@ -315,7 +317,7 @@ class Gen7EnvSinglePlayer(EnvPlayer):
_ACTION_SPACE = list(range(3 * 4 + 6))
_DEFAULT_BATTLE_FORMAT = "gen7randombattle"

def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
def action_to_order(self, action: np.int64, battle: AbstractBattle) -> BattleOrder:
"""Converts actions to move orders.
The conversion is done as follows:
Expand All @@ -336,7 +338,7 @@ def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
If the proposed action is illegal, a random legal move is performed.
:param action: The action to convert.
:type action: int
:type action: np.int64
:param battle: The battle in which to act.
:type battle: Battle
:return: the order to send to the server.
Expand Down Expand Up @@ -375,7 +377,7 @@ class Gen8EnvSinglePlayer(EnvPlayer):
_ACTION_SPACE = list(range(4 * 4 + 6))
_DEFAULT_BATTLE_FORMAT = "gen8randombattle"

def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
def action_to_order(self, action: np.int64, battle: AbstractBattle) -> BattleOrder:
"""Converts actions to move orders.
The conversion is done as follows:
Expand All @@ -402,7 +404,7 @@ def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
If the proposed action is illegal, a random legal move is performed.
:param action: The action to convert.
:type action: int
:type action: np.int64
:param battle: The battle in which to act.
:type battle: Battle
:return: the order to send to the server.
Expand Down Expand Up @@ -449,7 +451,7 @@ class Gen9EnvSinglePlayer(EnvPlayer):
_ACTION_SPACE = list(range(5 * 4 + 6))
_DEFAULT_BATTLE_FORMAT = "gen9randombattle"

def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
def action_to_order(self, action: np.int64, battle: AbstractBattle) -> BattleOrder:
"""Converts actions to move orders.
The conversion is done as follows:
Expand Down Expand Up @@ -479,7 +481,7 @@ def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
If the proposed action is illegal, a random legal move is performed.
:param action: The action to convert.
:type action: int
:type action: np.int64
:param battle: The battle in which to act.
:type battle: Battle
:return: the order to send to the server.
Expand Down
135 changes: 68 additions & 67 deletions src/poke_env/player/gymnasium_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
import asyncio
import time
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import numpy as np
from gymnasium.spaces import Discrete, Space
from pettingzoo.utils.env import ( # type: ignore[import-untyped]
ActionType,
ObsType,
ParallelEnv,
)
from pettingzoo.utils.env import ObsType, ParallelEnv # type: ignore[import-untyped]

from poke_env.concurrency import POKE_LOOP, create_in_poke_loop
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player.battle_order import BattleOrder, ForfeitBattleOrder
from poke_env.player.battle_order import (
BattleOrder,
DefaultBattleOrder,
ForfeitBattleOrder,
)
from poke_env.player.player import Player
from poke_env.ps_client import AccountConfiguration
from poke_env.ps_client.server_configuration import (
Expand All @@ -27,27 +28,32 @@
)
from poke_env.teambuilder.teambuilder import Teambuilder

ItemType = TypeVar("ItemType")


class _AsyncQueue:
def __init__(self, queue: asyncio.Queue[Any]):
class _AsyncQueue(Generic[ItemType]):
def __init__(self, queue: asyncio.Queue[ItemType]):
self.queue = queue

async def async_get(self):
async def async_get(self) -> ItemType:
return await self.queue.get()

def get(self, timeout: Optional[float] = None, default: Any = None):
def get(
self, timeout: Optional[float] = None, default: Optional[ItemType] = None
) -> ItemType:
try:
res = asyncio.run_coroutine_threadsafe(
asyncio.wait_for(self.async_get(), timeout), POKE_LOOP
)
return res.result()
except asyncio.TimeoutError:
assert default is not None
return default

async def async_put(self, item: Any):
async def async_put(self, item: ItemType):
await self.queue.put(item)

def put(self, item: Any):
def put(self, item: ItemType):
task = asyncio.run_coroutine_threadsafe(self.queue.put(item), POKE_LOOP)
task.result()

Expand All @@ -62,24 +68,22 @@ async def async_join(self):
await self.queue.join()


class _AsyncPlayer(Player):
actions: _AsyncQueue
observations: _AsyncQueue
class _EnvPlayer(Player):
battle_queue: _AsyncQueue[AbstractBattle]
order_queue: _AsyncQueue[BattleOrder]

def __init__(
self,
user_funcs: GymnasiumEnv,
username: str,
**kwargs: Any,
):
self.__class__.__name__ = username
super().__init__(**kwargs)
self.__class__.__name__ = "_AsyncPlayer"
self.observations = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1))
self.actions = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1))
self.__class__.__name__ = "_EnvPlayer"
self.battle_queue = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1))
self.order_queue = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1))
self.battle: Optional[AbstractBattle] = None
self.waiting = False
self._user_funcs = user_funcs

def choose_move(self, battle: AbstractBattle) -> Awaitable[BattleOrder]:
return self._env_move(battle)
Expand All @@ -89,21 +93,17 @@ async def _env_move(self, battle: AbstractBattle) -> BattleOrder:
self.battle = battle
if not self.battle == battle:
raise RuntimeError("Using different battles for queues")
battle_to_send = self._user_funcs.embed_battle(battle)
await self.observations.async_put(battle_to_send)
await self.battle_queue.async_put(battle)
self.waiting = True
action = await self.actions.async_get()
action = await self.order_queue.async_get()
self.waiting = False
if action == -1:
return ForfeitBattleOrder()
return self._user_funcs.action_to_move(action, battle)
return action

def _battle_finished_callback(self, battle: AbstractBattle):
to_put = self._user_funcs.embed_battle(battle)
asyncio.run_coroutine_threadsafe(self.observations.async_put(to_put), POKE_LOOP)
asyncio.run_coroutine_threadsafe(self.battle_queue.async_put(battle), POKE_LOOP)


class GymnasiumEnv(ParallelEnv[str, ObsType, ActionType]):
class GymnasiumEnv(ParallelEnv[str, ObsType, np.int64]):
"""
Base class implementing the Gymnasium API on the main thread.
"""
Expand Down Expand Up @@ -183,8 +183,7 @@ def __init__(
leave it inactive.
:type start_challenging: bool
"""
self.agent1 = _AsyncPlayer(
self,
self.agent1 = _EnvPlayer(
username=self.__class__.__name__, # type: ignore
account_configuration=account_configuration1,
avatar=avatar,
Expand All @@ -201,8 +200,7 @@ def __init__(
ping_timeout=ping_timeout,
team=team,
)
self.agent2 = _AsyncPlayer(
self,
self.agent2 = _EnvPlayer(
username=self.__class__.__name__, # type: ignore
account_configuration=account_configuration2,
avatar=avatar,
Expand All @@ -226,10 +224,6 @@ def __init__(
self.action_spaces = {
name: Discrete(self.action_space_size()) for name in self.possible_agents
}
self._actions1 = self.agent1.actions
self._observations1 = self.agent1.observations
self._actions2 = self.agent2.actions
self._observations2 = self.agent2.observations
self.battle1: Optional[AbstractBattle] = None
self.battle2: Optional[AbstractBattle] = None
self._keep_challenging: bool = False
Expand All @@ -245,7 +239,7 @@ def __init__(
# PettingZoo API
# https://pettingzoo.farama.org/api/parallel/#parallelenv

def step(self, actions: Dict[str, ActionType]) -> Tuple[
def step(self, actions: Dict[str, np.int64]) -> Tuple[
Dict[str, ObsType],
Dict[str, float],
Dict[str, bool],
Expand All @@ -257,18 +251,19 @@ def step(self, actions: Dict[str, ActionType]) -> Tuple[
if self.battle1.finished:
raise RuntimeError("Battle is already finished, call reset")
if self.agent1.waiting:
self._actions1.put(actions[self.agents[0]])
order1 = self.action_to_order(actions[self.agents[0]], self.battle1)
self.agent1.order_queue.put(order1)
if self.agent2.waiting:
self._actions2.put(actions[self.agents[1]])
order2 = self.action_to_order(actions[self.agents[1]], self.battle2)
self.agent2.order_queue.put(order2)
observations = {
self.agents[0]: self._observations1.get(
self.agents[0]: self.agent1.battle_queue.get(
timeout=0.1, default=self.embed_battle(self.battle1)
),
self.agents[1]: self._observations2.get(
self.agents[1]: self.agent2.battle_queue.get(
timeout=0.1, default=self.embed_battle(self.battle2)
),
}
assert self.battle1 == self.agent1.battle
reward = {
self.agents[0]: self.calc_reward(self.battle1),
self.agents[1]: self.calc_reward(self.battle2),
Expand All @@ -279,6 +274,7 @@ def step(self, actions: Dict[str, ActionType]) -> Tuple[
truncated = {self.agents[0]: trunc1, self.agents[1]: trunc2}
if self.battle1.finished:
self.agents = []
assert self.battle1 == self.agent1.battle
return observations, reward, terminated, truncated, self.get_additional_info()

def reset(
Expand All @@ -297,19 +293,21 @@ def reset(
time.sleep(self._TIME_BETWEEN_RETRIES)
if self.battle1 and not self.battle1.finished:
if self.battle1 == self.agent1.battle:
self._actions1.put(-1)
self._actions2.put(0)
self._observations1.get()
self._observations2.get()
self.agent1.order_queue.put(ForfeitBattleOrder())
self.agent2.order_queue.put(DefaultBattleOrder())
self.agent1.battle_queue.get()
self.agent2.battle_queue.get()
else:
raise RuntimeError(
"Environment and agent aren't synchronized. Try to restart"
)
while self.battle1 == self.agent1.battle:
time.sleep(0.01)
obs1 = self.agent1.battle_queue.get()
obs2 = self.agent2.battle_queue.get()
observations = {
self.agents[0]: self._observations1.get(),
self.agents[1]: self._observations2.get(),
self.agents[0]: self.embed_battle(obs1),
self.agents[1]: self.embed_battle(obs2),
}
self.battle1 = self.agent1.battle
self.battle1.logger = None
Expand Down Expand Up @@ -381,7 +379,7 @@ def calc_reward(self, battle: AbstractBattle) -> float:
pass

@abstractmethod
def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
def action_to_order(self, action: np.int64, battle: AbstractBattle) -> BattleOrder:
"""
Returns the BattleOrder relative to the given action.
Expand Down Expand Up @@ -568,20 +566,23 @@ async def _stop_challenge_loop(

if force:
if self.battle1 and not self.battle1.finished:
if not (self._actions1.empty() and self._actions2.empty()):
if not (self.agent1.battle_queue.empty() and self.agent2.battle_queue.empty()):
await asyncio.sleep(2)
if not (self._actions1.empty() and self._actions2.empty()):
if not (
self.agent1.order_queue.empty()
and self.agent2.order_queue.empty()
):
raise RuntimeError(
"The agent is still sending actions. "
"Use this method only when training or "
"evaluation are over."
)
if not self._observations1.empty():
await self._observations1.async_get()
if not self._observations2.empty():
await self._observations2.async_get()
await self._actions1.async_put(-1)
await self._actions2.async_put(0)
if not self.agent1.battle_queue.empty():
await self.agent1.battle_queue.async_get()
if not self.agent2.battle_queue.empty():
await self.agent2.battle_queue.async_get()
await self.agent1.order_queue.async_put(ForfeitBattleOrder())
await self.agent2.order_queue.async_put(DefaultBattleOrder())

if wait and self._challenge_task:
while not self._challenge_task.done():
Expand All @@ -593,14 +594,14 @@ async def _stop_challenge_loop(
self.battle2 = None
self.agent1.battle = None
self.agent2.battle = None
while not self._actions1.empty():
await self._actions1.async_get()
while not self._actions2.empty():
await self._actions2.async_get()
while not self._observations1.empty():
await self._observations1.async_get()
while not self._observations2.empty():
await self._observations2.async_get()
while not self.agent1.order_queue.empty():
await self.agent1.order_queue.async_get()
while not self.agent2.order_queue.empty():
await self.agent2.order_queue.async_get()
while not self.agent1.battle_queue.empty():
await self.agent1.battle_queue.async_get()
while not self.agent2.battle_queue.empty():
await self.agent2.battle_queue.async_get()

if purge:
self.reset_battles()
Expand Down
Loading

0 comments on commit 8bb067e

Please sign in to comment.