Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added display of probability distribution advice #38

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions diplomacy/communication/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class CreateGame(_AbstractChannelRequest):
game created and joined. Either a power game (if power name given) or an omniscient game.
"""
__slots__ = ['game_id', 'power_name', 'state', 'map_name', 'rules', 'n_controls', 'deadline',
'registration_password', 'daide_port', 'player_type']
'registration_password', 'daide_port', 'player_type', 'distribution_advice']
params = {
strings.GAME_ID: parsing.OptionalValueType(str),
strings.N_CONTROLS: parsing.OptionalValueType(int),
Expand All @@ -286,7 +286,8 @@ class CreateGame(_AbstractChannelRequest):
strings.MAP_NAME: parsing.DefaultValueType(str, 'standard'),
strings.RULES: parsing.OptionalValueType(parsing.SequenceType(str, sequence_builder=set)),
strings.DAIDE_PORT: parsing.OptionalValueType(int),
strings.PLAYER_TYPE: parsing.OptionalValueType(str)
strings.PLAYER_TYPE: parsing.OptionalValueType(str),
strings.DISTRIBUTION_ADVICE: parsing.DefaultValueType(dict, {})
}

def __init__(self, **kwargs):
Expand All @@ -300,6 +301,7 @@ def __init__(self, **kwargs):
self.rules = set()
self.daide_port = None
self.player_type = ''
self.distribution_advice = {}
super(CreateGame, self).__init__(**kwargs)


Expand Down Expand Up @@ -575,6 +577,23 @@ def __init__(self, **kwargs):
# Game requests.
# ==============

class GetOrderDistribution(_AbstractGameRequest):
"""
Game request to get model prediction (i.e., the probability distribution) of possible orders
for a selected province

:param power_name (str): power that requests the predictions
:param province (str): the province selected by the requested power
:param model(str): the type of model
"""
__slots__ = ['power_name', 'province', 'model']

params = {
'power_name': str,
'province': str,
'model': str
}

class ClearCenters(_AbstractGameRequest):
""" Game request to clear supply centers. See method :meth:`.Game.clear_centers`.

Expand Down
6 changes: 4 additions & 2 deletions diplomacy/communication/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class DataGameInfo(_AbstractResponse):
"""
__slots__ = ['game_id', 'phase', 'timestamp', 'map_name', 'rules', 'status', 'n_players',
'n_controls', 'deadline', 'registration_password', 'observer_level',
'controlled_powers', 'timestamp_created']
'controlled_powers', 'timestamp_created', 'distribution_advice']
params = {
strings.GAME_ID: str,
strings.PHASE: str,
Expand All @@ -141,7 +141,8 @@ class DataGameInfo(_AbstractResponse):
strings.N_PLAYERS: parsing.OptionalValueType(int),
strings.N_CONTROLS: parsing.OptionalValueType(int),
strings.DEADLINE: parsing.OptionalValueType(int),
strings.REGISTRATION_PASSWORD: parsing.OptionalValueType(bool)
strings.REGISTRATION_PASSWORD: parsing.OptionalValueType(bool),
strings.DISTRIBUTION_ADVICE: parsing.DefaultValueType(dict, {})
}

def __init__(self, **kwargs):
Expand All @@ -158,6 +159,7 @@ def __init__(self, **kwargs):
self.n_controls = None # type: int
self.deadline = None # type: int
self.registration_password = None # type: bool
self.distribution_advice = {} # type: dict
super(DataGameInfo, self).__init__(**kwargs)

class DataPossibleOrders(_AbstractResponse):
Expand Down
3 changes: 3 additions & 0 deletions diplomacy/engine/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class Game(Jsonable):

# pylint: disable=too-many-instance-attributes
__slots__ = [
"distribution_advice",
"victory",
"no_rules",
"meta_rules",
Expand Down Expand Up @@ -305,6 +306,7 @@ class Game(Jsonable):
zobrist_tables = {}
rule_cache = ()
model = {
strings.DISTRIBUTION_ADVICE: parsing.DefaultValueType(dict, {}),
strings.CONTROLLED_POWERS: parsing.OptionalValueType(parsing.SequenceType(str)),
strings.DAIDE_PORT: parsing.OptionalValueType(int),
strings.DEADLINE: parsing.DefaultValueType(int, 300),
Expand Down Expand Up @@ -401,6 +403,7 @@ class Game(Jsonable):

def __init__(self, game_id=None, **kwargs):
"""Constructor"""
self.distribution_advice = {}
self.victory = None
self.no_rules = set()
self.meta_rules = []
Expand Down
24 changes: 22 additions & 2 deletions diplomacy/server/request_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from diplomacy.utils.common import hash_password
from diplomacy.utils.constants import OrderSettings
from diplomacy.utils.game_phase_data import GamePhaseData
from diplomacy.utils.models import LogisticRegression

LOGGER = logging.getLogger(__name__)

Expand All @@ -48,6 +49,21 @@

SERVER_GAME_RULES = ['NO_PRESS', 'IGNORE_ERRORS', 'POWER_CHOICE']

DISTRIBUTION_MODELS = {
"standard_lr": LogisticRegression,
}

def on_get_order_distribution(server, request, connection_handler):
"""Manage request GetOrderDistribution.
"""
preds = {}
level = verify_request(server, request, connection_handler, require_master=False)
game_state = level.game.get_state()
requested_province = request.province
model = DISTRIBUTION_MODELS.get(request.model)
if model:
preds = model(game_state, requested_province).predict(top_k=10)
return responses.DataSavedGame(data=preds, request_id=request.request_id)

def on_clear_centers(server, request, connection_handler):
""" Manage request ClearCenters.
Expand Down Expand Up @@ -155,7 +171,9 @@ def on_create_game(server, request, connection_handler):
n_controls=request.n_controls,
deadline=request.deadline,
registration_password=password,
server=server)
server=server,
distribution_advice=request.distribution_advice
)

# Make sure game creator will be a game master (set him as moderator if he's not an admin).
if not server.users.has_admin(username):
Expand Down Expand Up @@ -723,7 +741,8 @@ def on_list_games(server, request, connection_handler):
n_players=server_game.count_controlled_powers(),
n_controls=server_game.get_expected_controls_count(),
deadline=server_game.deadline,
registration_password=bool(server_game.registration_password)
registration_password=bool(server_game.registration_password),
distribution_advice=server_game.distribution_advice
))
return responses.DataGames(data=selected_game_indices, request_id=request.request_id)

Expand Down Expand Up @@ -1344,6 +1363,7 @@ def on_vote(server, request, connection_handler):

# Mapping dictionary from request class to request handler function.
MAPPING = {
requests.GetOrderDistribution: on_get_order_distribution,
requests.ClearCenters: on_clear_centers,
requests.ClearOrders: on_clear_orders,
requests.ClearUnits: on_clear_units,
Expand Down
28 changes: 28 additions & 0 deletions diplomacy/utils/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import ABC

# MODEL MODULE IMPORTS
from baseline_models.model_code.engine_predict import BaselineAdvice

# MODEL PATHS (TODO: fill in the model paths here)
MODEL_PATHS = {
"logistic_regression": ""
}

# MODEL UTILS
class Model(ABC):
def __init__(self, game_state, requested_province):
self.game_state = game_state
self.requested_province = requested_province

def predict(self, top_k=6):
"""
Return the predicted distribution of possible orders at current game state
"""
raise NotImplementedError

# MODELS
class LogisticRegression(Model):
def predict(self, top_k=6):
model = BaselineAdvice(MODEL_PATHS["logistic_regression"], self.game_state, self.requested_province)
return model.predict(top_k)

1 change: 1 addition & 0 deletions diplomacy/utils/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# ==============================================================================
""" Some strings frequently used (to help prevent typos). """

DISTRIBUTION_ADVICE='distribution_advice'
ABBREV = 'abbrev'
ACTIVE = 'active'
ADJUST = 'adjust'
Expand Down
4 changes: 4 additions & 0 deletions diplomacy/web/src/diplomacy/client/channel.js
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ export class Channel {

//// Public game API.

getOrderDistribution(parameters, game){
return this._req('get_order_distribution', undefined, undefined, parameters, game);
}

getAllPossibleOrders(parameters, game) {
return this._req('get_all_possible_orders', undefined, undefined, parameters, game);
}
Expand Down
4 changes: 4 additions & 0 deletions diplomacy/web/src/diplomacy/client/network_game.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ export class NetworkGame {

//// Game requests API.

getOrderDistribution(parameters){
return this._req(Channel.prototype.getOrderDistribution, parameters);
}

getAllPossibleOrders(parameters) {
return this._req(Channel.prototype.getAllPossibleOrders, parameters);
}
Expand Down
3 changes: 3 additions & 0 deletions diplomacy/web/src/diplomacy/client/response_managers.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ export const RESPONSE_MANAGERS = {
delete_game: function (context, response) {
context.deleteGame();
},
get_order_distribution: function (context, response){
return response.data;
},
get_phase_history: function (context, response) {
for (let phaseData of response.data) {
context.game.local.extendPhaseHistory(phaseData);
Expand Down
4 changes: 3 additions & 1 deletion diplomacy/web/src/diplomacy/communication/requests.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ export const REQUESTS = {
level: STRINGS.CHANNEL,
model: {
game_id: null, n_controls: null, deadline: 300, registration_password: null,
power_name: null, state: null, map_name: 'standard', rules: null, player_type: null
power_name: null, state: null, map_name: 'standard', rules: null, player_type: null,
distribution_advice: {}
}
},
delete_account: {level: STRINGS.CHANNEL, model: {username: null}},
get_order_distribution: {level: STRINGS.GAME, model: {power_name: null, province: null, model: null}},
get_all_possible_orders: {level: STRINGS.GAME, model: {}},
get_available_maps: {level: STRINGS.CHANNEL, model: {}},
get_playable_powers: {level: STRINGS.CHANNEL, model: {game_id: null}},
Expand Down
1 change: 1 addition & 0 deletions diplomacy/web/src/diplomacy/engine/game.js
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ export class Game {
// {loc => order type}
this.orderableLocToTypes = null;
this.client = null; // Used as pointer to a NetworkGame.
this.distribution_advice = gameData.distribution_advice ? gameData?.distribution_advice : {};
}

get n_players() {
Expand Down
6 changes: 4 additions & 2 deletions diplomacy/web/src/gui/maps/common/build.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import PropTypes from "prop-types";

export class Build extends React.Component {
render() {
const opacity = (this.props?.opacity === undefined ? 1 : this.props?.opacity);
const Coordinates = this.props.coordinates;
const SymbolSizes = this.props.symbolSizes;
const loc = this.props.loc;
Expand All @@ -31,7 +32,7 @@ export class Build extends React.Component {

const symbol = unit_type === 'A' ? ARMY : FLEET;
return (
<g>
<g opacity={opacity}>
<use x={build_loc_x}
y={build_loc_y}
height={SymbolSizes[build_symbol].height}
Expand All @@ -53,5 +54,6 @@ Build.propTypes = {
loc: PropTypes.string.isRequired,
powerName: PropTypes.string.isRequired,
coordinates: PropTypes.object.isRequired,
symbolSizes: PropTypes.object.isRequired
symbolSizes: PropTypes.object.isRequired,
opacity: PropTypes.number
};
12 changes: 12 additions & 0 deletions diplomacy/web/src/gui/maps/common/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ export function setInfluence(classes, mapData, loc, power_name) {
classes[id] = power_name ? power_name.toLowerCase() : 'nopower';
}

export function setInfluenceLightBackground(classes, mapData, loc, power_name) {
const province = mapData.getProvince(loc);
if (!province)
throw new Error(`Unable to find province ${loc}`);
if (!['LAND', 'COAST'].includes(province.type))
return;
const id = province.getID(classes);
if (!id)
throw new Error(`Unable to find SVG path for loc ${id}`);
classes[id] = power_name ? `${power_name.toLowerCase()}light` : 'nopowerlight';
}

export function getClickedID(event) {
let node = event.target;
if (!node.id && node.parentNode.id && node.parentNode.tagName === 'g')
Expand Down
8 changes: 6 additions & 2 deletions diplomacy/web/src/gui/maps/common/convoy.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import PropTypes from "prop-types";

export class Convoy extends React.Component {
render() {
const opacity = (this.props?.opacity === undefined ? 1 : this.props?.opacity);
const Coordinates = this.props.coordinates;
const SymbolSizes = this.props.symbolSizes;
const Colors = this.props.colors;
Expand Down Expand Up @@ -56,7 +57,9 @@ export class Convoy extends React.Component {
dest_loc_y = '' + Math.round((parseFloat(src_loc_y) + (dest_vector_length - delta_dec) / dest_vector_length * dest_delta_y) * 100.) / 100.;

return (
<g stroke={Colors[this.props.powerName]}>
<g stroke={Colors[this.props.powerName]}
opacity={opacity}
>
<line x1={loc_x}
y1={loc_y}
x2={src_loc_x_1}
Expand Down Expand Up @@ -99,5 +102,6 @@ Convoy.propTypes = {
powerName: PropTypes.string.isRequired,
coordinates: PropTypes.object.isRequired,
symbolSizes: PropTypes.object.isRequired,
colors: PropTypes.object.isRequired
colors: PropTypes.object.isRequired,
opacity: PropTypes.number
};
6 changes: 4 additions & 2 deletions diplomacy/web/src/gui/maps/common/disband.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ import PropTypes from "prop-types";

export class Disband extends React.Component {
render() {
const opacity = (this.props?.opacity === undefined ? 1 : this.props?.opacity);
const Coordinates = this.props.coordinates;
const SymbolSizes = this.props.symbolSizes;
const loc = this.props.loc;
const phaseType = this.props.phaseType;
const symbol = 'RemoveUnit';
const [loc_x, loc_y] = centerSymbolAroundUnit(Coordinates, SymbolSizes, loc, phaseType === 'R', symbol);
return (
<g>
<g opacity={opacity}>
<use x={loc_x}
y={loc_y}
height={SymbolSizes[symbol].height}
Expand All @@ -43,5 +44,6 @@ Disband.propTypes = {
loc: PropTypes.string.isRequired,
phaseType: PropTypes.string.isRequired,
coordinates: PropTypes.object.isRequired,
symbolSizes: PropTypes.object.isRequired
symbolSizes: PropTypes.object.isRequired,
opacity: PropTypes.number
};
8 changes: 6 additions & 2 deletions diplomacy/web/src/gui/maps/common/hold.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ import PropTypes from "prop-types";

export class Hold extends React.Component {
render() {
const opacity = (this.props?.opacity === undefined ? 1 : this.props?.opacity);
const Coordinates = this.props.coordinates;
const Colors = this.props.colors;
const SymbolSizes = this.props.symbolSizes;
const symbol = 'HoldUnit';
const [loc_x, loc_y] = centerSymbolAroundUnit(Coordinates, SymbolSizes, this.props.loc, false, symbol);
return (
<g stroke={Colors[this.props.powerName]}>
<g stroke={Colors[this.props.powerName]}
opacity={opacity}
>
<use
x={loc_x}
y={loc_y}
Expand All @@ -43,5 +46,6 @@ Hold.propTypes = {
powerName: PropTypes.string.isRequired,
coordinates: PropTypes.object.isRequired,
symbolSizes: PropTypes.object.isRequired,
colors: PropTypes.object.isRequired
colors: PropTypes.object.isRequired,
opacity: PropTypes.number
};
6 changes: 4 additions & 2 deletions diplomacy/web/src/gui/maps/common/move.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import PropTypes from "prop-types";

export class Move extends React.Component {
render() {
const opacity = (this.props?.opacity === undefined ? 1 : this.props?.opacity);
const Coordinates = this.props.coordinates;
const SymbolSizes = this.props.symbolSizes;
const Colors = this.props.colors;
Expand All @@ -36,7 +37,7 @@ export class Move extends React.Component {
dest_loc_x = '' + Math.round((parseFloat(src_loc_x) + (vector_length - delta_dec) / vector_length * delta_x) * 100.) / 100.;
dest_loc_y = '' + Math.round((parseFloat(src_loc_y) + (vector_length - delta_dec) / vector_length * delta_y) * 100.) / 100.;
return (
<g>
<g opacity={opacity}>
<line x1={src_loc_x}
y1={src_loc_y}
x2={dest_loc_x}
Expand All @@ -63,5 +64,6 @@ Move.propTypes = {
phaseType: PropTypes.string.isRequired,
coordinates: PropTypes.object.isRequired,
symbolSizes: PropTypes.object.isRequired,
colors: PropTypes.object.isRequired
colors: PropTypes.object.isRequired,
opacity: PropTypes.number
};
Loading