Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
Human Game update

Update tracking deception to output in TACC log

debug PO corresponding threshold

fix issues

remove DAIDE translation from human games (game type 2); set input invalidation error

update after running dec_nodec

print update
  • Loading branch information
wwongkamjan authored and wwongkam committed May 15, 2023
1 parent 54078e5 commit 3aa5a9e
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 36 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ After each pull it's recommended to run `make` to re-compile internal C++ and pr
module load tacc-singularity
git clone --recursive https://github.com/ALLAN-DIP/diplomacy_cicero.git
cp -r /corral/projects/DARPA-SHADE/Shared/cicero "$WORK"
cp -r /corral/projects/DARPA-SHADE/Shared/cicero $WORK
cp /corral/projects/DARPA-SHADE/Shared/UMD/pytorch_model.bin "$WORK"/diplomacy_cicero/fairdiplomacy/AMR/amrlib/amrlib/data/model_parse_xfm/checkpoint-9920/
export CICERO=$WORK/cicero
cd "$CICERO"
cd $CICERO
singularity run --nv \
--bind "$WORK"/diplomacy_cicero/fairdiplomacy/agents/:/diplomacy_cicero/fairdiplomacy/agent \
--bind "$WORK"/diplomacy_cicero/fairdiplomacy_external:/diplomacy_cicero/fairdiplomacy_external \
Expand Down
18 changes: 17 additions & 1 deletion fairdiplomacy/agents/searchbot_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ def _generate_message_via_message_search(
)
meta_annotations.add_filtered_msg(data, msg_dict["time_sent"])
meta_annotations.after_message_generation_failed(bad_tags=[TOKEN_DETAILS_TAG])

self.recent_pseudo_orders = pseudo_orders
return selected_msg_dict

def _rank_candidate_messages(
Expand Down Expand Up @@ -1365,6 +1365,22 @@ def generate_message(

timings.stop()
timings.pprint(logging.getLogger("timings").info)

if maybe_msg_dict is not None and 'message' in maybe_msg_dict:
with timings.create_subcontext("po") as subtimings:
pseudo_orders = self.get_pseudo_orders(
game, power=power, state=state, recipient=recipient, timings=subtimings,
)
(corresponds_to_pseudo, extra_corr_info,) = self.message_handler.message_filterer._corresponds_to_pseudo_orders(
maybe_msg_dict, game, pseudo_orders,
)
if 'diff' in extra_corr_info:
print(f'in searchbot_agent.py message: {maybe_msg_dict} and deceptive info: {extra_corr_info}')
if extra_corr_info['diff'] >=extra_corr_info['thresh']:
maybe_msg_dict['deceptive'] = f"A truth to Cicero: {maybe_msg_dict['message']}"
else:
maybe_msg_dict['deceptive'] = f"A lie to Cicero: {maybe_msg_dict['message']}"

return maybe_msg_dict

def _get_phase_pseudo_orders(
Expand Down
128 changes: 101 additions & 27 deletions fairdiplomacy_external/mila_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
MESSAGE_DELAY_IF_SLEEP_INF = Timestamp.from_seconds(60)
ProtoMessage = google.protobuf.message.Message

DEFAULT_DEADLINE = 4
DEFAULT_DEADLINE = 5

import json
sys.path.insert(0, '/diplomacy_cicero/fairdiplomacy/AMR/DAIDE/DiplomacyAMR/code')
Expand All @@ -113,7 +113,7 @@

class milaWrapper:

def __init__(self, is_daide):
def __init__(self, is_deceptive):
self.game: NetworkGame = None
self.dipcc_game: Game = None
self.prev_state = 0 # number of number received messages in the current phase
Expand All @@ -126,9 +126,15 @@ def __init__(self, is_daide):
self.sent_FCT = {'RUSSIA':set(),'TURKEY':set(),'ITALY':set(),'ENGLAND':set(),'FRANCE':set(),'GERMANY':set(),'AUSTRIA':set()}
self.sent_PRP = {'RUSSIA':set(),'TURKEY':set(),'ITALY':set(),'ENGLAND':set(),'FRANCE':set(),'GERMANY':set(),'AUSTRIA':set()}
self.last_PRP_review_timestamp = {'RUSSIA':0,'TURKEY':0,'ITALY':0,'ENGLAND':0,'FRANCE':0,'GERMANY':0,'AUSTRIA':0}
self.daide = is_daide
self.last_comm_intent={'RUSSIA':None,'TURKEY':None,'ITALY':None,'ENGLAND':None,'FRANCE':None,'GERMANY':None,'AUSTRIA':None,'final':None}
self.deceptive = is_deceptive

agent_config = heyhi.load_config('/diplomacy_cicero/conf/common/agents/cicero.prototxt')
if self.deceptive:
agent_config = heyhi.load_config('/diplomacy_cicero/conf/common/agents/cicero_lie.prototxt')
print('CICERO deceptive')
else:
agent_config = heyhi.load_config('/diplomacy_cicero/conf/common/agents/cicero.prototxt')
print('Cicero non-deceptive')
print(f"successfully load cicero config")

self.agent = PyBQRE1PAgent(agent_config.bqre1p)
Expand All @@ -139,13 +145,16 @@ async def play_mila(
port: int,
game_id: str,
power_name: str,
game_type: int,
gamedir: Path,
) -> None:


self.power_name = power_name
print(f"Antony joining game: {game_id} as {power_name}")
connection = await connect(hostname, port)
dec = 'Deceptive_' if self.deceptive else ''
channel = await connection.authenticate(
f"Antony_{power_name}", "password"
f"{dec}Antony_{power_name}", "password"
)
self.game: NetworkGame = await channel.join_game(game_id=game_id, power_name=power_name)

Expand All @@ -161,15 +170,15 @@ async def play_mila(
print(f"Started dipcc game")

self.player = Player(self.agent, power_name)

num_beams = 4
batch_size = 16
device = 'cuda:0'
model_dir = '/diplomacy_cicero/fairdiplomacy/AMR/amrlib/amrlib/data/model_parse_xfm/checkpoint-9920/'
self.inference = Inference(model_dir, batch_size=batch_size, num_beams=num_beams, device=device)


while not self.game.is_game_done:
while not (self.game.is_game_done or self.game.get_current_phase() == "S1915M"):
self.phase_start_time = time.time()
self.dipcc_current_phase = self.game.get_current_phase()

Expand All @@ -185,17 +194,18 @@ async def play_mila(
await self.update_press_dipcc_game(power_name)
# reply/gen new message
msg = self.generate_message(power_name)
print(f'msg from cicero to dipcc {msg}')

if msg is not None:
draw_token_message = self.is_draw_token_message(msg,power_name)
proposal_response = self.check_PRP(msg,power_name)

# send message in dipcc and Mila
if msg is not None and not proposal_response and not draw_token_message:
if msg is not None and not proposal_response and not draw_token_message and msg['recipient'] in self.game.powers:
recipient_power = msg['recipient']
power_pseudo = self.player.state.pseudo_orders_cache.maybe_get(
self.dipcc_game, self.player.power, True, True, recipient_power)

power_po = power_pseudo[self.dipcc_current_phase]
for power in power_po.keys():
if power == power_name:
Expand All @@ -211,21 +221,51 @@ async def play_mila(
self_pseudo_log = f'After I got the message from {recipient_power}, I intend to do: {self_po}'
await self.send_log(self_pseudo_log)

list_msg = self.to_daide_msg(msg)


await self.send_log(f'I expect {recipient_power} to do: {recp_po}')
await self.send_log(f'My (internal) response is: {msg["message"]}')

if len(list_msg)>0:
for daide_msg in list_msg:
await self.send_log(f'My external DAIDE response is: {daide_msg["message"]}')
# keep track of intent that we talked to each recipient
self.set_comm_intent(recipient_power, power_po)

if game_type==0:
list_msg = self.to_daide_msg(msg)
if len(list_msg)>0:
for daide_msg in list_msg:
await self.send_log(f'My external DAIDE response is: {daide_msg["message"]}')
self.send_message(msg, 'dipcc')
else:
await self.send_log(f'No valid DIADE found / Attempt to send repeated FCT/PRP messages')

for msg in list_msg:
self.send_message(msg, 'mila')

elif game_type==1:
list_msg = self.to_daide_msg(msg)
self.send_message(msg, 'dipcc')
else:
await self.send_log(f'No valid DIADE found / Attempt to send repeated FCT/PRP messages')
self.send_message(msg, 'mila')
for daide_msg in list_msg:
await self.send_log(f'My external DAIDE response is: {daide_msg["message"]}')
else:
await self.send_log(f'No valid DIADE found / Attempt to send repeated FCT/PRP messages')

for msg in list_msg:
for msg in list_msg:
self.send_message(msg, 'mila')

elif game_type==2:
self.send_message(msg, 'dipcc')
self.send_message(msg, 'mila')

if 'deceptive' in msg:
await self.send_log(msg['deceptive'])
print(f'Cicero logs if message is deceptive: {msg["deceptive"]}')

# for daide_msg in list_msg:
# await self.send_log(f'My DAIDE response is: {daide_msg["message"]}')
# else:
# await self.send_log(f'No valid DIADE found / Attempt to send repeated FCT/PRP messages')


await asyncio.sleep(0.25)

Expand All @@ -234,9 +274,13 @@ async def play_mila(
print(f"Submit orders in {self.dipcc_current_phase}")
agent_orders = self.player.get_orders(self.dipcc_game)

# keep track of our final order
self.set_comm_intent('final', agent_orders)
await self.send_log(f'A record of intents in {self.dipcc_current_phase}: {self.get_comm_intent()}')

# set order in Mila
self.game.set_orders(power_name=power_name, orders=agent_orders, wait=False)

# wait until the phase changed
print(f"wait until {self.dipcc_current_phase} is done", end=" ")
while not self.has_phase_changed():
Expand All @@ -256,7 +300,15 @@ async def play_mila(
)
file.write("\n")

def reset_comm_intent(self):
    """Clear the recorded intents.

    Rebinds ``self.last_comm_intent`` to a fresh dict that maps each of the
    seven power names, plus the ``'final'`` slot, to ``None``.
    """
    slots = (
        'RUSSIA',
        'TURKEY',
        'ITALY',
        'ENGLAND',
        'FRANCE',
        'GERMANY',
        'AUSTRIA',
        'final',
    )
    # A new dict is created on every call, so earlier references are not mutated.
    self.last_comm_intent = dict.fromkeys(slots, None)

def get_comm_intent(self):
    """Return the intent record stored in ``self.last_comm_intent``.

    The stored object itself is returned, not a copy.
    """
    recorded_intents = self.last_comm_intent
    return recorded_intents

def set_comm_intent(self, recipient, pseudo_orders):
    """Record ``pseudo_orders`` under the key ``recipient``.

    Writes into the existing ``self.last_comm_intent`` mapping, overwriting
    whatever was previously stored for that key.
    """
    intents = self.last_comm_intent
    intents[recipient] = pseudo_orders

def check_PRP(self,msg,power_name):
phase_messages = self.get_messages(
messages=self.game.messages, power=power_name
Expand Down Expand Up @@ -300,6 +352,13 @@ def is_draw_token_message(self, msg ,power_name):
if UNDRAW_VOTE_TOKEN in msg['message']:
self.game.powers[power_name].vote = strings.NO
return True
if DATASET_DRAW_MESSAGE in msg['message']:
self.game.powers[power_name].vote = strings.YES
return True
if DATASET_NODRAW_MESSAGE in msg['message']:
self.game.powers[power_name].vote = strings.YES
return True

return False

def to_daide_msg(self, msg: MessageDict):
Expand Down Expand Up @@ -395,27 +454,30 @@ def psudo_code_gene(self,current_phase_code,message,power_dict,af_dict):
for country in current_phase_code.keys():
if country == message["sender"]:
#FCT for sender
has_FCT_order = True
# has_FCT_order = True
for i in current_phase_code[country]:
sen_length = len(i)
if sen_length == 11:
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[0]]+' '+i[2:5]+') MTO '+i[8:11]+'))'
has_FCT_order = True
elif sen_length == 7:
if i[6] == 'H':
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[0]]+' '+i[2:5]+') HLD))'
elif i[6] == 'B':
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[0]]+' '+i[2:5]+') BLD))'
elif i[6] == 'R':
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[0]]+' '+i[2:5]+') REM))'
has_FCT_order = True
elif sen_length == 19:
if i[6] =='S':
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[0]]+' '+i[2:5]+') SUP ('+power_dict[country]+' '+af_dict[i[8]]+' '+i[10:13]+') MTO '+i[16:19]+'))'
elif i[6] == 'C':
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[0]]+' '+i[2:5]+') CVY ('+power_dict[country]+' '+af_dict[i[8]]+' '+i[10:13]+') CTO '+i[16:19]+'))'
string1 += ' (XDO (('+power_dict[country]+' '+af_dict[i[8]]+' '+i[10:13]+') CTO '+i[16:19]+' VIA ('+i[2:5]+')))'
has_FCT_order = True
else:
#PRP for recipient
has_PRP_order = True
# has_PRP_order = True
for i in current_phase_code[country]:
sen_length = len(i)
if sen_length == 11:
Expand Down Expand Up @@ -486,6 +548,7 @@ def init_phase(self):
self.sent_FCT = {'RUSSIA':set(),'TURKEY':set(),'ITALY':set(),'ENGLAND':set(),'FRANCE':set(),'GERMANY':set(),'AUSTRIA':set()}
self.sent_PRP = {'RUSSIA':set(),'TURKEY':set(),'ITALY':set(),'ENGLAND':set(),'FRANCE':set(),'GERMANY':set(),'AUSTRIA':set()}
self.last_PRP_review_timestamp = {'RUSSIA':0,'TURKEY':0,'ITALY':0,'ENGLAND':0,'FRANCE':0,'GERMANY':0,'AUSTRIA':0}
self.reset_comm_intent()

def has_phase_changed(self)->bool:
"""
Expand Down Expand Up @@ -565,6 +628,7 @@ async def update_press_dipcc_game(self, power_name: POWERS):

# update message in dipcc game
for timesent, message in phase_messages.items():
print(f'message from mila to dipcc {message}')
if int(str(timesent)[0:10]) > int(str(self.last_received_message_time)[0:10]):

dipcc_timesent = Timestamp.from_seconds(timesent * 1e-6)
Expand Down Expand Up @@ -776,8 +840,10 @@ def update_past_phase(self, dipcc_game: Game, phase: str, power_name: str):
# and don't add it to the dipcc game.
# If it has at least one part that contains anything other than three upper letters,
# then just keep message body as original
if message.recipient == 'GLOBAL':
if message.recipient not in self.game.powers:
continue
# print(f'load message from mila to dipcc {message}')

if is_daide(message.message):
pre_processed = pre_process(message.message)
generated_English = gen_English(pre_processed, message.recipient, message.sender)
Expand Down Expand Up @@ -840,6 +906,12 @@ def main() -> None:
required=True,
help="power name",
)
parser.add_argument(
"--game_type",
type=int,
default=0,
help="0: AI-only game, 1: Human and AI game, 2: Human-only game",
)
# parser.add_argument(
# "--agent",
# type=Path,
Expand All @@ -848,10 +920,10 @@ def main() -> None:
# help="path to prototxt with agent's configurations (default: %(default)s)",
# )
parser.add_argument(
"--daide",
type=bool,
default= True,
help="Is Cicero a daide speaker or no?",
"--deceptive",
default= False,
action="store_true",
help="Is Cicero being deceptive? -- removing PO correspondence filter from message module?",
)
parser.add_argument(
"--outdir", type=Path, help="output directory for game json to be stored"
Expand All @@ -862,23 +934,25 @@ def main() -> None:
port: int = args.port
game_id: str = args.game_id
power: str = args.power
daide: bool = args.daide
deceptive: bool = args.deceptive
outdir: Optional[Path] = args.outdir
game_type : int = args.game_type

print(f"settings:")
print(f"host: {host}, port: {port}, game_id: {game_id}, power: {power}")

if outdir is not None and not outdir.is_dir():
outdir.mkdir(parents=True, exist_ok=True)

mila = milaWrapper(is_daide=daide)
mila = milaWrapper(is_deceptive=deceptive)

asyncio.run(
mila.play_mila(
hostname=host,
port=port,
game_id=game_id,
power_name=power,
game_type=game_type,
gamedir=outdir,
)
)
Expand Down
15 changes: 10 additions & 5 deletions parlai_diplomacy/utils/game2seq/format_helpers/message_editing.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,8 @@ def _corresponds_to_pseudo_orders(
- game: game object
- pseudo_orders: Joint action used to condition the dialogue model
"""
if self.pseudo_orders_correspondence_threshold is None:
return True, {}

pseudo_orders_correspondence_threshold = -5e-3

if not [m for p in game.get_all_phases() for m in p.messages.values()]:
# If there are no messages so far this game, bail
# This is because the "before" state will assume a no-press game which
Expand Down Expand Up @@ -635,11 +634,17 @@ def get_pseudo_logprob(game):
"before_prob": before_prob,
"after_prob": after_prob,
"diff": diff,
"thresh": self.pseudo_orders_correspondence_threshold,
"thresh": pseudo_orders_correspondence_threshold,
"pseudo_orders": sender_pseudo,
}

corresponds_to_pseudo = diff >= self.pseudo_orders_correspondence_threshold
corresponds_to_pseudo = diff >= pseudo_orders_correspondence_threshold
print(f'place one in message_editing.py message: {msg} and deceptive info {extra_info}')
print(f'self variable {self.pseudo_orders_correspondence_threshold}')
if self.pseudo_orders_correspondence_threshold == -1.0:
print(f'place two in message_editing.py message: {msg} and deceptive info {extra_info}')
return True, extra_info

return corresponds_to_pseudo, extra_info

def _edit_newlines(self, msg_txt: str) -> str:
Expand Down
Loading

0 comments on commit 3aa5a9e

Please sign in to comment.