Skip to content

Commit

Permalink
updated model plots and README
Browse files Browse the repository at this point in the history
  • Loading branch information
merillium committed Mar 7, 2024
1 parent 8a795dd commit a31824e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
36 changes: 29 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ To download and preprocess data from the lichess.org open database, you can run
python3 download_and_preprocess.py --year 2015 --month 1 --source lichess-open-database
```

Warning: preprocessing will take a long time to complete for more recent data which can be well over 20 GB in size. You can use caffeinate to prevent your computer from going to sleep while the script is running.

```bash
caffeinate -is python3 download_and_preprocess.py --year 2015 --month 1 --source lichess-open-database
```


The `download_and_preprocess.py` script downloads the `.pgn.zst` file corresponding to the month and year specified, decompresses the `.pgn` file, and creates the `lichess_downloaded_games` directory to which both files are saved. Then the script preprocesses the `.pgn` file and extracts relevant features, creates the `lichess_player_data` directory, to which a `.csv` file is saved. By default, exploratory plots are generated, and then all raw files in the `lichess_downloaded_games` directory are deleted because they are typically large and not needed after preprocessing. (This process can be streamlined by directly reading from the decompressed `.pgn` file instead of first saving it)

### Model Description
Expand All @@ -17,13 +24,6 @@ This is a simple statistical model that flags players who have performed a certa
### Model Training
We define `N` as the number of players who have performed above some threshold, and the estimated number of cheaters as `X = 0.00 * N_open + 0.75 * N_closed + 1.00 * N_violation` where `N_open` is the number of players with open accounts, `N_closed` is the number of players with closed accounts, and `N_violation` is the number of players with a terms of service violation (where `N = N_open + N_closed + N_violation`), the metric used to evaluate the performance of the threshold is the `log(N+1) * X / N`. This is a simple metric intended to reward the model for `high accuracy = X / N` in detecting suspicious players without flagging too many players (observationally, if the threshold is too low, the accuracy will decrease faster than `log(N)`). Note that for a threshold that is too high and flags 0 players, the metric will be 0. This metric may be fine-tuned in the future, but is sufficient for a POC.

Below is an example of the threshold vs accuracy plot below for players in the 1400-1500 range for classical chess based on training data from the month of Jan 2015.

![sample threshold vs accuracy plot](images/sample_model_threshold.png)

### Assumptions
The model is built on the assumption that cheating is a rare occurrence in any data set on which the model is trained. There may be unexpected behavior if the training data is composed predomininantly of players who are cheating. The model will retain its default thresholds in the event that no players have shown any significant deviations from the mean expected performance in their rating bin.

### Sample code:
```python
import pandas as pd
Expand All @@ -39,6 +39,28 @@ model.save_model(f'{BASE_FILE_NAME}_model')
predictions = model.predict(train_data)
```

### Model Evaluation

When the model is fitted, there are accuracy metric vs threshold figures that are saved to the `model_plots` directory as json files. The figure object can be loaded from the json file, as shown in the example code snippet below:

```python
import json
import plotly.io as pio

f = open('model_plots/test_model_thresholds_classical_1400-1500.json')
data = json.load(f)

fig = pio.from_json(data)
fig.show()
```

Below is an example of the threshold vs accuracy plot below for players in the 1400-1500 range for classical chess based on training data from the month of Jan 2015.

![sample threshold vs accuracy plot](images/sample_model_threshold.png)

### Assumptions
The model is built on the assumption that cheating is a rare occurrence in any data set on which the model is trained. There may be unexpected behavior if the training data is composed predomininantly of players who are cheating. The model will retain its default thresholds in the event that no players have shown any significant deviations from the mean expected performance in their rating bin.

### Unit Tests
Currently working on unit tests, which can be run with the following command:
```make test```, or if you want to run test files individually ```PYTHONPATH=. pytest tests/test_model.py```
Expand Down
3 changes: 3 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def _set_thresholds(self, train_data, generate_plots):
best_train_metric = 0.00

train_accuracy_list = []
train_metric_list = []
train_threshold_list = []
train_number_of_flagged_players = []

Expand Down Expand Up @@ -121,6 +122,7 @@ def _set_thresholds(self, train_data, generate_plots):
## a threshold that flags 100 players with 0.50 accuracy
## is worse than a threshold that flags 20 players with 1.00 accuracy
train_metric = np.log(number_of_flagged_players + 1) * train_accuracy
train_metric_list.append(train_metric)

## update the best threshold
if train_metric > best_train_metric:
Expand All @@ -147,6 +149,7 @@ def _set_thresholds(self, train_data, generate_plots):
Folders.MODEL_PLOTS.value,
train_threshold_list,
train_accuracy_list,
train_metric_list,
train_number_of_flagged_players,
best_threshold,
time_control,
Expand Down
33 changes: 26 additions & 7 deletions model_plots.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import json
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import plotly.io as pio

def generate_model_threshold_plots(
base_file_name,
model_plots_folder,
train_threshold_list,
train_accuracy_list,
train_metric_list,
train_number_of_flagged_players,
best_threshold,
time_control,
Expand All @@ -18,7 +20,17 @@ def generate_model_threshold_plots(
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(
go.Scatter(
x=train_threshold_list, y=train_accuracy_list, name="Accuracy vs Threshold"
x=train_threshold_list,
y=train_metric_list,
name="Train Metric"
),
secondary_y=False,
)
fig.add_trace(
go.Scatter(
x=train_threshold_list,
y=train_accuracy_list,
name="Accuracy"
),
secondary_y=False,
)
Expand All @@ -38,14 +50,21 @@ def generate_model_threshold_plots(
annotation_text="Best Threshold",
)
fig.update_layout(
title=f"Accuracy vs Threshold for {time_control}: Rating Bin {rating_bin_key}",
title=f"Accuracy Metric vs Threshold for {time_control}: Rating Bin {rating_bin_key}",
xaxis_title="Threshold",
yaxis_title="Accuracy",
yaxis_title="Metric/Accuracy Values",
yaxis2_title="Number of Flagged Players",
yaxis_range=[0, 1],
)
if not os.path.exists(model_plots_folder):
os.mkdir(model_plots_folder)
fig.write_html(
f"{model_plots_folder}/{base_file_name}_model_thresholds_{time_control}_{rating_bin_key}.html"
)

# fig.write_html(
# f"{model_plots_folder}/{base_file_name}_model_thresholds_{time_control}_{rating_bin_key}.html"
# )

## we may want both htmls and jsons that can be used to directly import and access the figure objects
fig_json = fig.to_json()
model_plot_filename = f"{model_plots_folder}/{base_file_name}_model_thresholds_{time_control}_{rating_bin_key}.json"
with open(model_plot_filename, 'w') as f:
json.dump(fig_json, f)

0 comments on commit a31824e

Please sign in to comment.