Skip to content

Commit

Permalink
Refactor game sync results to use SQS
Browse files Browse the repository at this point in the history
  • Loading branch information
Kataiser committed Jan 27, 2025
1 parent fc7d1e0 commit 22ee44d
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 54 deletions.
74 changes: 61 additions & 13 deletions db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import atexit
import dataclasses
import decimal
import enum
import json
import os
from operator import itemgetter
from typing import Union, Any
Expand All @@ -13,6 +16,7 @@
class Table:
def __init__(self, table_name: str, primary_key: str):
self.table_name = table_name
self.table_full_name = f'CelesteTAS-Improvement-Tracker_{self.table_name}'
self.primary_key = primary_key
self.caching = False
self.cache = {}
Expand All @@ -23,12 +27,12 @@ def get(self, key: Union[str, int], consistent_read: bool = True, keep_primary_k

key_type = 'S' if isinstance(key, str) else 'N'
actual_consistent_read = False if always_inconsistent_read else consistent_read
item = client.get_item(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}', Key={self.primary_key: {key_type: str(key)}}, ConsistentRead=actual_consistent_read)
item = dynamodb_client.get_item(TableName=self.table_full_name, Key={self.primary_key: {key_type: str(key)}}, ConsistentRead=actual_consistent_read)

if 'Item' in item:
item_deserialized = deserializer.deserialize({'M': item['Item']})
else:
raise DBKeyError(f"'{key}' not found in table 'CelesteTAS-Improvement-Tracker_{self.table_name}'")
raise DBKeyError(f"'{key}' not found in table '{self.table_full_name}'")

if '_value' in item_deserialized:
result = item_deserialized['_value']
Expand Down Expand Up @@ -58,7 +62,7 @@ def set(self, key: Union[str, int], value: Any, get_previous: bool = False) -> A
value = {self.primary_key: key, '_value': value}

return_values = 'ALL_OLD' if get_previous else 'NONE'
response = client.put_item(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}', Item=serializer.serialize(value)['M'], ReturnValues=return_values)
response = dynamodb_client.put_item(TableName=self.table_full_name, Item=serializer.serialize(value)['M'], ReturnValues=return_values)

if added_primary:
del value[self.primary_key]
Expand All @@ -73,7 +77,7 @@ def set(self, key: Union[str, int], value: Any, get_previous: bool = False) -> A

def get_all(self, consistent_read: bool = True) -> list:
actual_consistent_read = False if always_inconsistent_read else consistent_read
items = client.scan(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}', ConsistentRead=actual_consistent_read)
items = dynamodb_client.scan(TableName=self.table_full_name, ConsistentRead=actual_consistent_read)
return [deserializer.deserialize({'M': item}) for item in items['Items']]

def dict(self, consistent_read: bool = True) -> dict:
Expand All @@ -92,16 +96,16 @@ def delete_item(self, key: Union[str, int]):
return

key_type = 'S' if isinstance(key, str) else 'N'
client.delete_item(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}', Key={self.primary_key: {key_type: str(key)}})
dynamodb_client.delete_item(TableName=self.table_full_name, Key={self.primary_key: {key_type: str(key)}})

def metadata(self) -> dict:
return client.describe_table(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}')
return dynamodb_client.describe_table(TableName=self.table_full_name)

def size(self, consistent_read: bool = True) -> int:
if consistent_read:
return client.scan(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}', Select='COUNT', ConsistentRead=True)['Count']
return dynamodb_client.scan(TableName=self.table_full_name, Select='COUNT', ConsistentRead=True)['Count']
else:
return client.describe_table(TableName=f'CelesteTAS-Improvement-Tracker_{self.table_name}')['Table']['ItemCount']
return dynamodb_client.describe_table(TableName=self.table_full_name)['Table']['ItemCount']

def enable_cache(self):
self.caching = True
Expand Down Expand Up @@ -186,6 +190,49 @@ def add_project_key(key: str, value: Any):
print(f"Added `{key}: {value}` to {len(projects_)} projects, be sure to update command_register_project")


class SyncResultType(enum.StrEnum):
NORMAL = enum.auto()
MAINGAME_COMMIT = enum.auto()
AUTO_DISABLE = enum.auto()
REPORTED_ERROR = enum.auto()


@dataclasses.dataclass
class SyncResult:
    """One message pulled from the sync-results SQS queue.

    type: which kind of result this is (drives handling in main.py).
    data: the deserialized JSON payload produced by send_sync_result().
    receipt_handle: SQS receipt handle, needed by delete_sync_result().
    """
    type: SyncResultType
    data: dict
    receipt_handle: str

    def __str__(self) -> str:
        type_label = str(self.type).upper()
        return f"SyncResult type={type_label} data={self.data}"


def send_sync_result(result_type: SyncResultType, data: dict):
    """Publish a game-sync result to the SQS FIFO queue.

    result_type: category of the result; also used as the message group ID so
                 results of the same type are delivered in order.
    data: JSON-serializable payload, read back by get_sync_results().
    """
    payload = {'type': str(result_type), 'data': data}
    # Use stdlib json here: the original called ujson.dumps, but this module
    # imports json (and get_sync_results parses with json.loads), and ujson is
    # not imported here — json keeps serialization symmetric and avoids a NameError.
    # NOTE(review): no MessageDeduplicationId is sent; this assumes the .fifo
    # queue has content-based deduplication enabled — confirm queue settings.
    sqs_client.send_message(QueueUrl=sqs_queue_url, MessageBody=json.dumps(payload, ensure_ascii=False), MessageGroupId=str(result_type))


def get_sync_results() -> list[SyncResult]:
    """Receive up to 10 pending sync results from the SQS queue.

    Returns a list of SyncResult objects. Each carries its SQS receipt handle
    so the caller can acknowledge it later with delete_sync_result().
    """
    results = []
    response = sqs_client.receive_message(QueueUrl=sqs_queue_url, MaxNumberOfMessages=10)

    # An empty queue yields a response with no 'Messages' key at all, so guard
    # on that alone. The original also required HTTPStatusCode == 200, which
    # meant a messageless non-200 response fell through and raised KeyError on
    # response['Messages'] below.
    if 'Messages' not in response:
        return results

    for message in response['Messages']:
        body = json.loads(message['Body'])
        results.append(SyncResult(type=SyncResultType(body['type']),
                                  data=body['data'],
                                  receipt_handle=message['ReceiptHandle']))

    return results


def delete_sync_result(sync_result: SyncResult):
    """Acknowledge a handled sync result, removing it from the SQS queue so it
    won't be redelivered after the visibility timeout expires."""
    # The original ended with `del sync_result`, which only unbinds the local
    # parameter name and has no effect on the object or the queue; removed as a no-op.
    sqs_client.delete_message(QueueUrl=sqs_queue_url, ReceiptHandle=sync_result.receipt_handle)


class DBKeyError(Exception):
    """Raised by Table.get() when the requested key is not present in the
    corresponding DynamoDB table."""
    pass

Expand All @@ -196,16 +243,17 @@ class DBKeyError(Exception):
project_logs = Table('project_logs', 'project_id')
sheet_writes = Table('sheet_writes', 'timestamp')
logs = Table('logs', 'time')
sync_results = Table('sync_results', 'project_id')
misc = Table('misc', 'key')
contributors = Table('contributors', 'project_id')
sid_caches = Table('sid_caches', 'project_id')
tokens = Table('tokens', 'installation_owner')
projects = Projects('projects', 'project_id')
path_caches = PathCaches('path_caches', 'project_id')

client = boto3.client('dynamodb')
atexit.register(client.close)
dynamodb_client = boto3.client('dynamodb')
sqs_client = boto3.client('sqs')
sqs_queue_url = sqs_client.get_queue_url(QueueName='CelesteTAS-Improvement-Tracker_sync_results.fifo')['QueueUrl']
atexit.register(dynamodb_client.close)
serializer = TypeSerializer()
deserializer = TypeDeserializer()
always_inconsistent_read = False
Expand Down Expand Up @@ -240,7 +288,7 @@ class DBKeyError(Exception):
project_logs.set(int(project_log_name.removesuffix('.json')), project_log_loaded)

print(project_logs.get(970380662907482142))
print(client.describe_table(TableName='CelesteTAS-Improvement-Tracker_installations'))
print(dynamodb_client.describe_table(TableName='CelesteTAS-Improvement-Tracker_installations'))
print(installations.metadata())

with open('sync\\installations.json', 'r', encoding='UTF8') as installations_json:
Expand Down Expand Up @@ -292,4 +340,4 @@ class DBKeyError(Exception):
sid_caches.set(int(project_id), sid_caches_loaded[project_id])

print(sid_caches.get(1180581916529922188))
client.close()
dynamodb_client.close()
9 changes: 5 additions & 4 deletions game_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def sync_test(project_id: int, force: bool):
if asserts_added:
log.info(f"Added SID assertions to {len(asserts_added)} file{plural(asserts_added)}: {asserts_added}")

log.info(f"Previous desyncs: {previous_desyncs}")
game_process = wait_for_game_load(mods_to_load, project['name'])

for tas_filename in path_cache:
Expand Down Expand Up @@ -377,7 +378,7 @@ def sync_test(project_id: int, force: bool):

disabled_text = consider_disabling_after_inactivity(project, clone_time, False)
db.projects.set(project_id, project)
db.sync_results.set(project_id, {'report_text': report_text, 'disabled_text': disabled_text, 'log': report_log, 'crash_logs': crash_logs_data})
db.send_sync_result(db.SyncResultType.NORMAL, {'project_id': project_id, 'report_text': report_text, 'disabled_text': disabled_text, 'log': report_log, 'crash_logs': crash_logs_data})
log.info("Wrote sync result to DB")

# commit updated fullgame files
Expand All @@ -403,7 +404,7 @@ def sync_test(project_id: int, force: bool):
log.info(f"Successfully committed: {commit_url}")

if project_is_maingame:
db.sync_results.set(int(time.time()), {'maingame_message': f"Committed `{commit_message}` <{commit_url}>"})
db.send_sync_result(db.SyncResultType.MAINGAME_COMMIT, {'maingame_message': f"Committed `{commit_message}` <{commit_url}>"})

log.info(f"Sync check time: {format_elapsed_time(start_time)}")

Expand Down Expand Up @@ -661,7 +662,7 @@ def consider_disabling_after_inactivity(project: dict, reference_time: Union[int

if from_abandoned:
db.projects.set(project['project_id'], project)
db.sync_results.set(project['project_id'], {'report_text': None, 'disabled_text': disabled_text})
db.send_sync_result(db.SyncResultType.AUTO_DISABLE, {'project_id': project['project_id'], 'disabled_text': disabled_text})
else:
# don't need to return projects since it's mutable
return disabled_text
Expand Down Expand Up @@ -735,7 +736,7 @@ def scaled_sleep(seconds: float):

def log_error(message: Optional[str] = None):
error = utils.log_error(message)
db.sync_results.set(int(time.time()), {'reported_error': True, 'error': error[-1950:]})
db.send_sync_result(db.SyncResultType.REPORTED_ERROR, {'time': int(time.time()), 'error': error[-1950:]})


log: Union[logging.Logger, utils.LogPlaceholder] = utils.LogPlaceholder()
Expand Down
65 changes: 28 additions & 37 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,64 +454,55 @@ def convert_line_endings(tas: bytes, old_tas: Optional[bytes]) -> bytes:

@tasks.loop(minutes=1)
async def handle_game_sync_results():
sync_results_found = db.sync_results.get_all()
sync_results = db.get_sync_results()

if not sync_results_found:
if not sync_results:
return

global client

for sync_result in sync_results_found:
project_id = int(sync_result['project_id'])
for sync_result in sync_results:
log.info(f"Handling {str(sync_result)}")
db.delete_sync_result(sync_result)

try:
if sync_result.type in (db.SyncResultType.NORMAL, db.SyncResultType.AUTO_DISABLE):
project_id = sync_result.data['project_id']
project = db.projects.get(project_id)
except db.DBKeyError:
if 'reported_error' in sync_result:
log.info(f"Reporting sync check error: {sync_result}")
await (await utils.user_from_id(client, admin_user_id)).send(f"<t:{project_id}:R>\n```\n{sync_result['error']}```")
else:
log.info(f"Reporting updated maingame time: \"{sync_result['maingame_message']}\"")
await client.get_channel(1323811411226263654).send(sync_result['maingame_message'])

db.sync_results.delete_item(project_id)
continue
project_name = project['name']
improvements_channel = client.get_channel(project_id)
await edit_pin(improvements_channel)

project_name = project['name']
log.info(f"Handling game sync result for project {project_name}")
report_text = sync_result['report_text']
disabled_text = sync_result['disabled_text']
improvements_channel = client.get_channel(project_id)
await edit_pin(improvements_channel)

if report_text:
if sync_result['log']:
match sync_result.type:
case db.SyncResultType.NORMAL:
sync_check_time = project['last_run_validation']
files = [discord.File(io.BytesIO(sync_result['log'].value), filename=f'game_sync_{project_name}_{sync_check_time}.log.gz')]
files = [discord.File(io.BytesIO(sync_result.data['log'].value), filename=f'game_sync_{project_name}_{sync_check_time}.log.gz')]

for crash_log_name in sync_result['crash_logs']:
crash_log_data = sync_result['crash_logs'][crash_log_name].value
for crash_log_name in sync_result.data['crash_logs']:
crash_log_data = sync_result.data['crash_logs'][crash_log_name].value
files.append(discord.File(io.BytesIO(crash_log_data), filename=crash_log_name))

await improvements_channel.send(report_text, files=files)
else:
await improvements_channel.send(report_text)
await improvements_channel.send(sync_result.data['report_text'], files=files)

case db.SyncResultType.AUTO_DISABLE:
await improvements_channel.send(sync_result.data['disabled_text'])

case db.SyncResultType.REPORTED_ERROR:
await (await utils.user_from_id(client, admin_user_id)).send(f"<t:{sync_result.data['time']}:R>\n```\n{sync_result.data['error']}```")

if disabled_text:
await improvements_channel.send(disabled_text)
case db.SyncResultType.MAINGAME_COMMIT:
await client.get_channel(1323811411226263654).send(sync_result.data['maingame_message'])

db.sync_results.delete_item(project_id)
db.delete_sync_result(sync_result)

db.misc.set('last_game_sync_result_time', int(time.time()))


@tasks.loop(hours=2)
async def handle_no_game_sync_results():
if not db.sync_results.size():
time_since_last_game_sync_result = time.time() - float(db.misc.get('last_game_sync_result_time'))
time_since_last_game_sync_result = time.time() - float(db.misc.get('last_game_sync_result_time'))

if time_since_last_game_sync_result > 86400: # 24 hours
await (await utils.user_from_id(client, admin_user_id)).send(f"Warning: last sync check was {round(time_since_last_game_sync_result / 3600, 1)} hours ago")
if time_since_last_game_sync_result > 86400: # 24 hours
await (await utils.user_from_id(client, admin_user_id)).send(f"Warning: last sync check was {round(time_since_last_game_sync_result / 3600, 1)} hours ago")


@tasks.loop(seconds=30)
Expand Down

0 comments on commit 22ee44d

Please sign in to comment.