From a31824eede2e7da1224e89bb0aa9ce83c9ad2736 Mon Sep 17 00:00:00 2001
From: merillium
Date: Thu, 7 Mar 2024 16:09:54 -0500
Subject: [PATCH] updated model plots and README

---
 README.md      | 36 +++++++++++++++++++++++++++++-------
 model.py       |  3 +++
 model_plots.py | 33 ++++++++++++++++++++++++++-------
 3 files changed, 58 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 8f04f72..e5fed26 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,13 @@ To download and preprocess data from the lichess.org open database, you can run
 python3 download_and_preprocess.py --year 2015 --month 1 --source lichess-open-database
 ```
 
+Warning: preprocessing will take a long time to complete for more recent data, which can be well over 20 GB in size. On macOS, you can use `caffeinate` to prevent your computer from going to sleep while the script is running:
+
+```bash
+caffeinate -is python3 download_and_preprocess.py --year 2015 --month 1 --source lichess-open-database
+```
+
 The `download_and_preprocess.py` script downloads the `.pgn.zst` file corresponding to the specified month and year, decompresses the `.pgn` file, and creates the `lichess_downloaded_games` directory, to which both files are saved. The script then preprocesses the `.pgn` file, extracts relevant features, and creates the `lichess_player_data` directory, to which a `.csv` file is saved. By default, exploratory plots are generated, and then all raw files in the `lichess_downloaded_games` directory are deleted because they are typically large and not needed after preprocessing. (This process could be streamlined by reading directly from the decompressed `.pgn` stream instead of first saving it; a sketch is shown below.)
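+
+(A sketch of that streamlining, not part of the current script: the third-party `zstandard` package can stream games straight out of the compressed file, so the decompressed `.pgn` never needs to be written to disk. The file name below is illustrative.)
+
+```python
+import io
+
+import zstandard as zstd
+
+## stream-decompress the .pgn.zst and iterate over it line by line,
+## without writing the decompressed .pgn to disk
+path = "lichess_downloaded_games/lichess_db_standard_rated_2015-01.pgn.zst"
+with open(path, "rb") as fh:
+    reader = zstd.ZstdDecompressor().stream_reader(fh)
+    pgn_lines = io.TextIOWrapper(reader, encoding="utf-8")
+    n_games = 0
+    for line in pgn_lines:
+        if line.startswith("[Event "):  ## each game starts with an [Event ...] tag
+            n_games += 1
+print(f"games seen: {n_games}")
+```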
 
 ### Model Description
@@ -17,13 +24,6 @@ This is a simple statistical model that flags players who have performed a certa
 ### Model Training
 We define `N` as the number of players who have performed above some threshold, and estimate the number of cheaters as `X = 0.00 * N_open + 0.75 * N_closed + 1.00 * N_violation`, where `N_open` is the number of players with open accounts, `N_closed` is the number of players with closed accounts, and `N_violation` is the number of players with a terms of service violation (so that `N = N_open + N_closed + N_violation`). The metric used to evaluate the performance of a threshold is `log(N+1) * X / N`. This simple metric is intended to reward the model for high accuracy (`X / N`) in detecting suspicious players without flagging too many players (observationally, if the threshold is too low, the accuracy will decrease faster than `log(N)` grows). Note that for a threshold that is so high it flags 0 players, the metric will be 0. This metric may be fine-tuned in the future, but is sufficient for a POC.
 
-Below is an example of the threshold vs accuracy plot below for players in the 1400-1500 range for classical chess based on training data from the month of Jan 2015.
-
-![sample threshold vs accuracy plot](images/sample_model_threshold.png)
-
-### Assumptions
-The model is built on the assumption that cheating is a rare occurrence in any data set on which the model is trained. There may be unexpected behavior if the training data is composed predomininantly of players who are cheating. The model will retain its default thresholds in the event that no players have shown any significant deviations from the mean expected performance in their rating bin.
-
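+For illustration, here is the metric computed for a single hypothetical threshold (the counts below are made up):
+
+```python
+import numpy as np
+
+## hypothetical counts of players flagged at one threshold
+N_open, N_closed, N_violation = 10, 4, 6
+N = N_open + N_closed + N_violation  ## 20 flagged players
+
+## estimated number of cheaters, weighted by account status
+X = 0.00 * N_open + 0.75 * N_closed + 1.00 * N_violation  ## 9.0
+
+accuracy = X / N  ## 0.45
+metric = np.log(N + 1) * accuracy  ## log(21) * 0.45 ~= 1.37
+```
+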
 ### Sample code:
 ```python
 import pandas as pd
@@ -39,6 +39,28 @@ model.save_model(f'{BASE_FILE_NAME}_model')
 predictions = model.predict(train_data)
 ```
 
+### Model Evaluation
+
+When the model is fitted, figures of the accuracy metric vs threshold are saved to the `model_plots` directory as json files. The figure object can be loaded back from a json file, as shown in the example code snippet below:
+
+```python
+import json
+import plotly.io as pio
+
+## the json file stores the string returned by fig.to_json(), so
+## json.load recovers the string and pio.from_json rebuilds the figure
+with open('model_plots/test_model_thresholds_classical_1400-1500.json') as f:
+    data = json.load(f)
+
+fig = pio.from_json(data)
+fig.show()
+```
+
+Below is an example of the accuracy metric vs threshold plot for players in the 1400-1500 rating bin for classical chess, based on training data from the month of Jan 2015.
+
+![sample threshold vs accuracy plot](images/sample_model_threshold.png)
+
+### Assumptions
+The model is built on the assumption that cheating is a rare occurrence in any data set on which the model is trained. There may be unexpected behavior if the training data is composed predominantly of players who are cheating. The model will retain its default thresholds in the event that no players have shown any significant deviations from the mean expected performance in their rating bin.
+
 ### Unit Tests
 Currently working on unit tests, which can be run with the following command: ```make test```, or if you want to run test files individually: ```PYTHONPATH=. pytest tests/test_model.py```
diff --git a/model.py b/model.py
index dc4c3a0..41331f7 100644
--- a/model.py
+++ b/model.py
@@ -79,6 +79,7 @@ def _set_thresholds(self, train_data, generate_plots):
         best_train_metric = 0.00
 
         train_accuracy_list = []
+        train_metric_list = []
         train_threshold_list = []
         train_number_of_flagged_players = []
 
@@ -121,6 +122,7 @@ def _set_thresholds(self, train_data, generate_plots):
                 ## a threshold that flags 100 players with 0.50 accuracy
                 ## is worse than a threshold that flags 20 players with 1.00 accuracy
                 train_metric = np.log(number_of_flagged_players + 1) * train_accuracy
+                train_metric_list.append(train_metric)
 
                 ## update the best threshold
                 if train_metric > best_train_metric:
@@ -147,6 +149,7 @@ def _set_thresholds(self, train_data, generate_plots):
                     Folders.MODEL_PLOTS.value,
                     train_threshold_list,
                     train_accuracy_list,
+                    train_metric_list,
                     train_number_of_flagged_players,
                     best_threshold,
                     time_control,
diff --git a/model_plots.py b/model_plots.py
index 712bd86..10e8a56 100644
--- a/model_plots.py
+++ b/model_plots.py
@@ -1,13 +1,15 @@
+import json
 import os
 
 import plotly.graph_objects as go
+import plotly.io as pio
 from plotly.subplots import make_subplots
 
 def generate_model_threshold_plots(
     base_file_name,
     model_plots_folder,
     train_threshold_list,
     train_accuracy_list,
+    train_metric_list,
     train_number_of_flagged_players,
     best_threshold,
     time_control,
@@ -18,7 +20,17 @@
     fig = make_subplots(specs=[[{"secondary_y": True}]])
     fig.add_trace(
         go.Scatter(
-            x=train_threshold_list, y=train_accuracy_list, name="Accuracy vs Threshold"
+            x=train_threshold_list,
+            y=train_metric_list,
+            name="Train Metric"
+        ),
+        secondary_y=False,
+    )
+    fig.add_trace(
+        go.Scatter(
+            x=train_threshold_list,
+            y=train_accuracy_list,
+            name="Accuracy"
         ),
         secondary_y=False,
     )
@@ -38,14 +50,21 @@
         annotation_text="Best Threshold",
     )
     fig.update_layout(
-        title=f"Accuracy vs Threshold for {time_control}: Rating Bin {rating_bin_key}",
+        title=f"Accuracy Metric vs Threshold for {time_control}: Rating Bin {rating_bin_key}",
         xaxis_title="Threshold",
-        yaxis_title="Accuracy",
+        yaxis_title="Metric/Accuracy Values",
         yaxis2_title="Number of Flagged Players",
-        yaxis_range=[0, 1],
+        ## no fixed primary y-range: the metric log(N+1) * X / N can exceed 1
     )
     if not os.path.exists(model_plots_folder):
         os.mkdir(model_plots_folder)
-    fig.write_html(
-        f"{model_plots_folder}/{base_file_name}_model_thresholds_{time_control}_{rating_bin_key}.html"
-    )
+
+    # fig.write_html(
+    #     f"{model_plots_folder}/{base_file_name}_model_thresholds_{time_control}_{rating_bin_key}.html"
+    # )
+
+    ## save the figure as json so the figure object can be loaded and accessed
+    ## directly (the html export above can be re-enabled if both are wanted)
+    fig_json = fig.to_json()
+    model_plot_filename = f"{model_plots_folder}/{base_file_name}_model_thresholds_{time_control}_{rating_bin_key}.json"
+    with open(model_plot_filename, 'w') as f:
+        json.dump(fig_json, f)
\ No newline at end of file