Commit 64db3e0
added test for model predict method
merillium committed Feb 5, 2024
1 parent b2fe31b commit 64db3e0
Showing 3 changed files with 79 additions and 7 deletions.
6 changes: 0 additions & 6 deletions model.py
@@ -82,9 +82,6 @@ def _set_thresholds(self, train_data, generate_plots):
             ]["player"].tolist()
 
             number_of_flagged_players = len(all_flagged_players)
-            print(
-                f"number of flagged players in {rating_bin_key} = {number_of_flagged_players}"
-            )
 
             ## break if threshold is large enough to filter out all players
             if number_of_flagged_players == 0:
@@ -94,9 +91,6 @@ def _set_thresholds(self, train_data, generate_plots):
             for player in all_flagged_players:
                 self._player_account_handler.update_player_account_status(player)
 
-            print("updated account statuses:")
-            print(self._player_account_handler._account_statuses)
-
             ## get the account status for each player
             train_predictions = [
                 self._player_account_handler._account_statuses.get(player)
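Note: this hunk only removes ad-hoc print debugging from `_set_thresholds`. If the diagnostics are still wanted, a minimal sketch of a logging-based replacement is below; the `report_flagged` helper is hypothetical and not part of this commit.

```python
import logging

logger = logging.getLogger(__name__)


def report_flagged(rating_bin_key: str, number_of_flagged_players: int) -> None:
    # Emit the per-bin flag count at DEBUG level instead of printing,
    # so normal runs keep stdout clean (hypothetical helper).
    logger.debug(
        "number of flagged players in %s = %d",
        rating_bin_key,
        number_of_flagged_players,
    )
```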
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,5 +8,6 @@ python-lichess==0.10 # Client for lichess.org API
 mock==5.1.0
 pylint==3.0.3
 pytest==7.0.1
+pytest-dependency==0.6.0
 
 debugpy # Required for debugging.
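Note: the new `pytest-dependency` pin is presumably for ordering dependent tests; none of the hunks shown here use it yet. A minimal sketch of the plugin's markers, with hypothetical test names:

```python
import pytest


@pytest.mark.dependency()
def test_fit():
    assert True


# Runs only if test_fit passed; otherwise pytest-dependency skips it.
@pytest.mark.dependency(depends=["test_fit"])
def test_predict():
    assert True
```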
79 changes: 78 additions & 1 deletion tests/test_model.py
@@ -1,6 +1,7 @@
 import copy
 import unittest
 import pandas as pd
+from pandas.testing import assert_frame_equal
 import pytest
 from unittest import mock
 from model import PlayerAnomalyDetectionModel
@@ -38,7 +39,40 @@ def get_sample_train_data():
     return sample_train_data
 
 
+# fixture for sample test data
+# players 7-12 where players 10, 11, 12 perform above expected
+@pytest.fixture(scope="class")
+def get_sample_test_data():
+    sample_test_data = pd.DataFrame(
+        {
+            "player": [f"test_player{i}" for i in range(7, 6 + 7)] * 2,
+            "time_control": ["blitz"] * 6 + ["bullet"] * 6,
+            "number_of_games": [100] * 12,
+            "mean_perf_diff": [0.04, 0.06, 0.15, 0.161, 0.17, 0.25]
+            + [0.15, 0.15, 0.16, 0.171, 0.18, 0.25],
+            "std_perf_diff": [0.005] * 12,
+            "mean_rating": [1510, 1520, 1530, 1540, 1550, 1560]
+            + [1510, 1520, 1530, 1540, 1550, 1560],
+            "median_rating": [1510, 1520, 1530, 1540, 1550, 1560]
+            + [1510, 1520, 1530, 1540, 1550, 1560],
+            "std_rating": [10] * 12,
+            "mean_opponent_rating": [1510, 1520, 1530, 1540, 1550, 1560]
+            + [1510, 1520, 1530, 1540, 1550, 1560],
+            "std_opponent_rating": [10] * 12,
+            "mean_rating_gain": [1.00, -1.00, 1.00, -1.00, 1.00, -1.00]
+            + [1.00, -1.00, 1.00, -1.00, 1.00, -1.00],
+            "std_rating_gain": [0.01] * 12,
+            "proportion_increment_games": [1.00, 1.00, 1.00, 1.00, 1.00, 1.00]
+            + [0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
+            "rating_bin": [1500] * 6 + [1600] * 6,
+        }
+    )
+
+    return sample_test_data
+
+
 @pytest.mark.usefixtures("get_sample_train_data", "build_training_data")
+@pytest.mark.usefixtures("get_sample_test_data", "build_training_data")
 class TestPlayerAnomalyDetectionModel(unittest.TestCase):
     @mock.patch(
         "player_account_handler.PlayerAccountHandler.update_player_account_status"
@@ -51,6 +85,10 @@ def setUp(self, mock_update_player_account_status):
     def build_training_data(self, get_sample_train_data):
         self.sample_train_data = get_sample_train_data
 
+    @pytest.fixture(autouse=True)
+    def build_test_data(self, get_sample_test_data):
+        self.sample_test_data = get_sample_test_data
+
     def test_fit(self):
         ## this is a workaround to avoid calling get_player_account_status
         self.model._player_account_handler._account_statuses = {
@@ -71,7 +109,46 @@ def test_fit(self):
         assert expected_thresholds == self.model._thresholds
 
     def test_predict(self):
-        pass
+        self.model._player_account_handler._account_statuses = {
+            "test_player7": "open",
+            "test_player8": "open",
+            "test_player9": "open",
+            "test_player10": "closed",
+            "test_player11": "tosViolation",
+            "test_player12": "tosViolation",
+        }
+        self.model.fit(self.sample_test_data, generate_plots=False)
+        expected_predictions = self.sample_test_data.copy()
+        expected_predictions["is_anomaly"] = [
+            False,
+            False,
+            False,
+            True,
+            True,
+            True,
+            False,
+            False,
+            False,
+            True,
+            True,
+            True,
+        ]
+        expected_predictions["account_status"] = [
+            "open",
+            "open",
+            "open",
+            "closed",
+            "tosViolation",
+            "tosViolation",
+            "open",
+            "open",
+            "open",
+            "closed",
+            "tosViolation",
+            "tosViolation",
+        ]
+        test_predictions = self.model.predict(self.sample_test_data)
+        assert_frame_equal(expected_predictions, test_predictions)
 
     def test_save_model(self):
         pass
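Note: `test_predict` compares whole DataFrames with `pandas.testing.assert_frame_equal`, which fails with a detailed report instead of a bare boolean assert. A small self-contained illustration (the frames below are made up for the example, not taken from the fixtures):

```python
import pandas as pd
from pandas.testing import assert_frame_equal

expected = pd.DataFrame({"player": ["test_player7"], "is_anomaly": [False]})
actual = pd.DataFrame({"player": ["test_player7"], "is_anomaly": [False]})

# Checks shape, dtypes, index, column order, and values;
# raises AssertionError with a readable diff on any mismatch.
assert_frame_equal(expected, actual)

# check_like=True ignores row/column ordering when that is acceptable.
assert_frame_equal(expected, actual, check_like=True)
```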
