diff --git a/baseline.ipynb b/baseline.ipynb new file mode 100644 index 0000000..c587ffc --- /dev/null +++ b/baseline.ipynb @@ -0,0 +1,2846 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "player_count_check = pd.read_csv('G:/datasets/csgo/match-map-unique/train/match-58-de_inferno-17-687.csv')\n", + "player_count = player_count_check['SteamId'].nunique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "player_count" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from csgo_wp.data_transform import CSGODataset, transform_data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transforming raw data...\n" + ] + }, + { + "ename": "ValueError", + "evalue": "only one element tensors can be converted to Python scalars", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mwarnings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfilterwarnings\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'ignore'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mdataset\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mCSGODataset\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtransform_data\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + 
"\u001b[1;32mg:\\git\\csgo-win-probability\\csgo_wp\\data_transform.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, folder, transform, dataset_split, rng_seed)\u001b[0m\n\u001b[0;32m 243\u001b[0m target = self.rounds[(self.rounds['MatchId'] == match_id)\n\u001b[0;32m 244\u001b[0m \u001b[1;33m&\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrounds\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'MapName'\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mmap_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 245\u001b[1;33m & (self.rounds['RoundNum'] == round_num)]\n\u001b[0m\u001b[0;32m 246\u001b[0m \u001b[0mtarget\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtarget\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'WinningSide'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'CT'\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 247\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtargets\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtarget\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtransformed\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mValueError\u001b[0m: only one element tensors can be converted to Python scalars" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "dataset = CSGODataset(transform=transform_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [ + "with open('G:/datasets/csgo/train/data.pckl', 'wb') as f:\n", + " pickle.dump((dataset.data, dataset.targets), f)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "df = pd.read_csv('G:/datasets/csgo/match-map-unique/train/match-1071-de_dust2-1-278.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PlayerSteamId
MatchId
107110
\n", + "
" + ], + "text/plain": [ + " PlayerSteamId\n", + "MatchId \n", + "1071 10" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.groupby(['MatchId']).agg({'PlayerSteamId': 'nunique'})" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PlayerSteamId
Tick
289
619
949
1279
1609
......
90399
90729
91059
91389
91719
\n", + "

278 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " PlayerSteamId\n", + "Tick \n", + "28 9\n", + "61 9\n", + "94 9\n", + "127 9\n", + "160 9\n", + "... ...\n", + "9039 9\n", + "9072 9\n", + "9105 9\n", + "9138 9\n", + "9171 9\n", + "\n", + "[278 rows x 1 columns]" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.groupby(['Tick']).agg({'PlayerSteamId': 'nunique'})" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([4], dtype=int64)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grouped = df.groupby(['Tick', 'Side'], as_index=False).agg({'PlayerSteamId': 'nunique'})\n", + "grouped[grouped['Side'] == 'T']['PlayerSteamId'].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "https://steamcommunity.com/profiles/76561198047402862/\n", + "Showed up in 140 ticks\n", + "https://steamcommunity.com/profiles/76561197997981170/\n", + "Showed up in 278 ticks\n", + "https://steamcommunity.com/profiles/76561197994395491/\n", + "Showed up in 278 ticks\n", + "https://steamcommunity.com/profiles/76561197978321481/\n", + "Showed up in 278 ticks\n", + "https://steamcommunity.com/profiles/76561198012987839/\n", + "Showed up in 138 ticks\n" + ] + } + ], + "source": [ + "for x in df[df['Side'] == 'T']['PlayerSteamId'].unique():\n", + " print(f'https://steamcommunity.com/profiles/{x}/')\n", + " print(f\"Showed up in {df[df['PlayerSteamId'] == x]['Tick'].nunique()} ticks\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2502" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape[0]" + ] + }, + { + "cell_type": "code", + 
"execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2502" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "278*9" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "g:\\git\\csgo-win-probability\\csgo_wp\\data_transform.py:53: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " df['pos'] = df[['X', 'Y', 'Z']].values.tolist()\n" + ] + }, + { + "data": { + "text/plain": [ + "(tensor([[ 1., 50., 51., 2., 43., 1., 8., 3., 38., 46.],\n", + " [36., 1., 2., 35., 8., 36., 39., 36., 8., 2.],\n", + " [37., 2., 1., 36., 9., 37., 40., 37., 9., 3.],\n", + " [ 2., 49., 50., 1., 42., 2., 7., 2., 37., 45.],\n", + " [30., 8., 9., 29., 1., 30., 32., 30., 4., 4.],\n", + " [ 1., 50., 51., 2., 43., 1., 8., 3., 38., 46.],\n", + " [24., 53., 54., 23., 46., 24., 1., 22., 41., 49.],\n", + " [ 3., 50., 51., 2., 43., 3., 8., 1., 38., 46.],\n", + " [29., 8., 9., 28., 4., 29., 31., 29., 1., 5.],\n", + " [32., 2., 3., 31., 4., 32., 35., 32., 5., 1.],\n", + " [ 1., 1., 1., 0., 1., 0., 1., 1., 1., 0.],\n", + " [ 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.]]), tensor([1.]))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "1 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "2 torch.Size([24, 12, 10]) torch.Size([24, 
1])\n", + "3 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "4 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "5 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "6 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "7 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "8 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "9 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "10 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "11 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "12 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "13 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "14 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "15 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "16 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "17 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "18 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "19 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "20 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "21 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "22 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "23 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "24 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "25 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "26 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "27 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "28 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "29 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "30 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "31 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "32 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "33 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "34 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "35 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "36 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "37 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "38 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "39 
torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "40 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "41 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "42 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "43 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "44 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "45 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "46 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "47 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "48 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "49 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "50 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "51 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "52 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "53 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "54 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "55 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "56 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "57 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "58 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "59 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "60 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "61 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "62 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "63 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "64 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "65 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "66 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "67 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "68 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "69 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "70 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "71 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "72 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "73 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "74 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "75 torch.Size([24, 12, 
10]) torch.Size([24, 1])\n", + "76 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "77 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "78 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "79 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "80 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "81 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "82 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "83 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "84 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "85 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "86 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "87 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "88 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "89 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "90 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "91 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "92 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "93 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "94 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "95 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "96 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "97 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "98 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "99 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "100 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "101 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "102 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "103 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "104 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "105 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "106 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "107 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "108 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "109 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "110 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "111 torch.Size([24, 12, 10]) 
torch.Size([24, 1])\n", + "112 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "113 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "114 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "115 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "116 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "117 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "118 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "119 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "120 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "121 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "122 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "123 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "124 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "125 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "126 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "127 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "128 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "129 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "130 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "131 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "132 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "133 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "134 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "135 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "136 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "137 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "138 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "139 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "140 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "141 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "142 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "143 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "144 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "145 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "146 torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "147 
torch.Size([24, 12, 10]) torch.Size([24, 1])\n", + "148 torch.Size([24, 12, 10]) torch.Size([24, 1])\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[0mstart\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;31m# check if dataset acts as expected\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 13\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mindex\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 14\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mstart\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\python37\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 361\u001b[0m 
\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 362\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 363\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 364\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 365\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\python37\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 401\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 402\u001b[0m \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 403\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 404\u001b[0m \u001b[1;32mif\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 405\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\python37\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\python37\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 42\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 45\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mg:\\git\\csgo-win-probability\\csgo_wp\\data_transform.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, sample_idx)\u001b[0m\n\u001b[0;32m 224\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmatchup\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatchups\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 225\u001b[0m \u001b[0mtick_count\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmatchup\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstem\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'-'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# last component\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 226\u001b[1;33m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 227\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_samples\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_samples\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mtick_count\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 228\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatchup_idx_by_sample_idx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_samples\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mg:\\git\\csgo-win-probability\\csgo_wp\\data_transform.py\u001b[0m in \u001b[0;36mtransform_data\u001b[1;34m(df, game_map)\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[1;31m# return torch.cat((t, t_2), dim=2).unsqueeze(1)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[0mn_samples\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 82\u001b[1;33m result = torch.cat([t.reshape(n_samples, 100),\n\u001b[0m\u001b[0;32m 83\u001b[0m t_2.reshape(n_samples, 20)],\n\u001b[0;32m 
84\u001b[0m dim=1).view(n_samples, 12, 10)\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "import torch\n", + "import time\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "loader = torch.utils.data.DataLoader(dataset,\n", + " batch_size=24,\n", + " shuffle=False,\n", + " num_workers=0,\n", + " )\n", + "start = time.time()\n", + "# check if dataset acts as expected\n", + "for index, (data, target) in enumerate(loader):\n", + " print(index, data.shape, target.shape)\n", + "print(time.time() - start)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# del dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "player_count_check = pd.read_csv('G:/datasets/csgo/match-map-unique/train/match-96-de_dust2-11-147.csv')\n", + "player_count = player_count_check['PlayerSteamId'].nunique()\n", + "player_count" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "tick_count_check = pd.read_csv('G:/datasets/csgo/csgo_playerframes_dust2.csv',\n", + " names=['MatchId',\n", + " 'MapName',\n", + " 'RoundNum',\n", + " 'Tick',\n", + " 'Second',\n", + " 'PlayerId',\n", + " 'PlayerSteamId',\n", + " 'TeamId',\n", + " 'Side',\n", + " 'X',\n", + " 'Y',\n", + " 'Z',\n", + " 'ViewX',\n", + " 'ViewY',\n", + " 'AreaId',\n", + " 'Hp',\n", + " 'Armor',\n", + " 'IsAlive',\n", + " 'IsFlashed',\n", + " 'IsAirborne',\n", + " 'IsDucking',\n", + " 'IsScoped',\n", + " 'IsWalking',\n", + " 'EqValue',\n", + " 'HasHelmet',\n", + " 'HasDefuse',\n", + " 'DistToBombsiteA',\n", + " 'DistToBombsiteB',\n", + " 'Created',\n", + " 'Updated'],\n", + " usecols=['MatchId', 'MapName', 'RoundNum', 'Tick', 
'PlayerSteamId'])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PlayerSteamId
MatchId
9610
\n", + "
" + ], + "text/plain": [ + " PlayerSteamId\n", + "MatchId \n", + "96 10" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tick_count_check[(tick_count_check['MatchId'] == 96)].groupby('MatchId').agg({'PlayerSteamId': 'nunique'})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Baseline" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
RoundNum
WinningSide
CT0.474391
T0.525609
\n", + "
" + ], + "text/plain": [ + " RoundNum\n", + "WinningSide \n", + "CT 0.474391\n", + "T 0.525609" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ideas: map-based, player count based, logreg, peterx's stuff\n", + "import pandas as pd\n", + "\n", + "map_data = pd.read_csv('G:/datasets/csgo/csgo_rounds_dust2.csv', usecols=['RoundNum', 'WinningSide'])\n", + "map_baseline = map_data.groupby('WinningSide').agg({'RoundNum': 'count'})\n", + "map_baseline = map_baseline / map_baseline.sum(axis=0)\n", + "map_baseline" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "print(roc_auc_score(map_data['WinningSide'] == 'CT', np.zeros(map_data.shape[0])))\n", + "del map_data" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MatchIdRoundNumSideTickIsAliveWinningSide
041CT1525751
141CT1529051
241CT1532351
341CT1535651
441CT1538951
.....................
2131902189222T30653011
2131903189222T30659611
2131904189222T30666211
2131905189222T30672811
2131906189222T30679411
\n", + "

2131907 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " MatchId RoundNum Side Tick IsAlive WinningSide\n", + "0 4 1 CT 15257 5 1\n", + "1 4 1 CT 15290 5 1\n", + "2 4 1 CT 15323 5 1\n", + "3 4 1 CT 15356 5 1\n", + "4 4 1 CT 15389 5 1\n", + "... ... ... ... ... ... ...\n", + "2131902 1892 22 T 306530 1 1\n", + "2131903 1892 22 T 306596 1 1\n", + "2131904 1892 22 T 306662 1 1\n", + "2131905 1892 22 T 306728 1 1\n", + "2131906 1892 22 T 306794 1 1\n", + "\n", + "[2131907 rows x 6 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# player count based\n", + "import pandas as pd\n", + "\n", + "# Column layout handed to read_csv via names=; only five of these columns are loaded.\n", + "FRAME_COLUMNS = ['MatchId', 'MapName', 'RoundNum', 'Tick', 'Second',\n", + "                 'PlayerId', 'PlayerSteamId', 'TeamId', 'Side',\n", + "                 'X', 'Y', 'Z', 'ViewX', 'ViewY', 'AreaId',\n", + "                 'Hp', 'Armor', 'IsAlive', 'IsFlashed', 'IsAirborne',\n", + "                 'IsDucking', 'IsScoped', 'IsWalking', 'EqValue',\n", + "                 'HasHelmet', 'HasDefuse', 'DistToBombsiteA', 'DistToBombsiteB',\n", + "                 'Created', 'Updated']\n", + "ROUND_KEY = ['MatchId', 'RoundNum']\n", + "\n", + "frames = pd.read_csv('G:/datasets/csgo/csgo_playerframes_dust2.csv',\n", + "                     names=FRAME_COLUMNS,\n", + "                     usecols=['MatchId', 'RoundNum', 'Tick', 'Side', 'IsAlive'])\n", + "rounds = pd.read_csv('G:/datasets/csgo/csgo_rounds_dust2.csv', usecols=['MatchId', 'RoundNum', 'WinningSide'])\n", + "\n", + "# One row per (match, round, side, tick): number of players alive on that side,\n", + "# plus the round outcome encoded as 1 when CT won and 0 otherwise.\n", + "player_data = (\n", + "    frames.assign(IsAlive=frames['IsAlive'].astype(int))\n", + "    .merge(rounds, on=ROUND_KEY)\n", + "    .assign(WinningSide=lambda merged: (merged['WinningSide'] == 'CT').astype(int))\n", + "    .groupby(['MatchId', 'RoundNum', 'Side', 'Tick'], as_index=False)\n", + "    .agg({'IsAlive': 'sum', 'WinningSide': 'max'})\n", + ")\n", + "del frames, rounds\n", + "player_data" + ] + }, + { + "cell_type": "code", + "execution_count": 29, +
"metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
IsAliveWinner
SideCTT
MatchIdRoundNumTick
41152575.05.01
152905.05.01
153235.05.01
153565.05.01
153895.05.01
..................
1892223065303.01.01
3065963.01.01
3066623.01.01
3067283.01.01
3067943.01.01
\n", + "

1069034 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " IsAlive Winner\n", + "Side CT T \n", + "MatchId RoundNum Tick \n", + "4 1 15257 5.0 5.0 1\n", + " 15290 5.0 5.0 1\n", + " 15323 5.0 5.0 1\n", + " 15356 5.0 5.0 1\n", + " 15389 5.0 5.0 1\n", + "... ... ... ...\n", + "1892 22 306530 3.0 1.0 1\n", + " 306596 3.0 1.0 1\n", + " 306662 3.0 1.0 1\n", + " 306728 3.0 1.0 1\n", + " 306794 3.0 1.0 1\n", + "\n", + "[1069034 rows x 3 columns]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# One row per (match, round, tick) with the alive count of each side as columns.\n", + "pivoted = player_data.pivot_table(index=['MatchId', 'RoundNum', 'Tick'], values=['IsAlive'], columns=['Side'])\n", + "# Attach the round outcome (1 when CT won) to every tick of the round.\n", + "pivoted['Winner'] = player_data.groupby(['MatchId', 'RoundNum', 'Tick']).agg({'WinningSide': 'max'})['WinningSide']\n", + "pivoted" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Int64Index([ 4, 15, 21, 26, 28, 29, 37, 38, 41, 47,\n", + " ...\n", + " 1676, 1677, 1690, 1693, 1698, 1786, 1799, 1853, 1878, 1892],\n", + " dtype='int64', name='MatchId', length=173)" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pivoted.index.levels[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1069034 850701 218333\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "# Fixed seed: the split used the unseeded global RNG before, so the row counts\n", + "# and every metric derived from train/test changed on each run.\n", + "SPLIT_SEED = 13\n", + "\n", + "# Flatten the index only once. The previous reset_index(inplace=True) made this cell\n", + "# fail when re-run, because pivoted.index.levels no longer exists after the first run.\n", + "if isinstance(pivoted.index, pd.MultiIndex):\n", + "    pivoted = pivoted.reset_index(drop=False)\n", + "match_ids = np.unique(pivoted['MatchId'].values)\n", + "\n", + "# Hold out whole matches (80/20) so frames of one match never land in both sets.\n", + "mask = np.random.RandomState(SPLIT_SEED).choice(match_ids, replace=False, size=int(match_ids.shape[0] * 0.8))\n", + "in_train = pivoted['MatchId'].isin(mask)\n", + "train, test = pivoted[in_train], pivoted[~in_train]\n", + "print(pivoted.shape[0], train.shape[0], test.shape[0])" + ] + },
+ { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "\n", + "# Mean of Winner (1 when CT won) for every (CT alive, T alive) pair in the training matches.\n", + "player_baseline = train.groupby([('IsAlive', 'CT'), ('IsAlive', 'T')], as_index=False).agg({('Winner', ''): 'mean'})\n", + "player_baseline.columns = ['CT count', 'T count', 'Win probability']\n", + "# Keyword arguments: DataFrame.pivot does not accept these positionally in pandas 2.x.\n", + "player_baseline_pivoted = player_baseline.pivot(index='CT count', columns='T count', values='Win probability')" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Reverse the row order so the CT count increases from bottom to top of the heatmap.\n", + "flipped_baseline = player_baseline_pivoted.iloc[::-1]\n", + "sns.heatmap(data=flipped_baseline, annot=True);" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
CT countT countWin probability
00.00.00.701897
10.01.00.431849
20.02.00.069738
30.03.00.051923
40.04.00.013245
50.05.00.246753
61.00.00.767359
71.01.00.314527
81.02.00.063484
91.03.00.016426
101.04.00.003314
111.05.00.001309
122.00.00.903884
132.01.00.730738
142.02.00.323620
152.03.00.091989
162.04.00.065905
172.05.00.005273
183.00.00.922530
193.01.00.900248
203.02.00.606403
213.03.00.357637
223.04.00.160770
233.05.00.074659
244.00.00.918157
254.01.00.923495
264.02.00.751979
274.03.00.556022
284.04.00.378026
294.05.00.277427
305.00.00.875318
315.01.00.819805
325.02.00.853229
335.03.00.735363
345.04.00.580107
355.05.00.480956
\n", + "
" + ], + "text/plain": [ + " CT count T count Win probability\n", + "0 0.0 0.0 0.701897\n", + "1 0.0 1.0 0.431849\n", + "2 0.0 2.0 0.069738\n", + "3 0.0 3.0 0.051923\n", + "4 0.0 4.0 0.013245\n", + "5 0.0 5.0 0.246753\n", + "6 1.0 0.0 0.767359\n", + "7 1.0 1.0 0.314527\n", + "8 1.0 2.0 0.063484\n", + "9 1.0 3.0 0.016426\n", + "10 1.0 4.0 0.003314\n", + "11 1.0 5.0 0.001309\n", + "12 2.0 0.0 0.903884\n", + "13 2.0 1.0 0.730738\n", + "14 2.0 2.0 0.323620\n", + "15 2.0 3.0 0.091989\n", + "16 2.0 4.0 0.065905\n", + "17 2.0 5.0 0.005273\n", + "18 3.0 0.0 0.922530\n", + "19 3.0 1.0 0.900248\n", + "20 3.0 2.0 0.606403\n", + "21 3.0 3.0 0.357637\n", + "22 3.0 4.0 0.160770\n", + "23 3.0 5.0 0.074659\n", + "24 4.0 0.0 0.918157\n", + "25 4.0 1.0 0.923495\n", + "26 4.0 2.0 0.751979\n", + "27 4.0 3.0 0.556022\n", + "28 4.0 4.0 0.378026\n", + "29 4.0 5.0 0.277427\n", + "30 5.0 0.0 0.875318\n", + "31 5.0 1.0 0.819805\n", + "32 5.0 2.0 0.853229\n", + "33 5.0 3.0 0.735363\n", + "34 5.0 4.0 0.580107\n", + "35 5.0 5.0 0.480956" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "player_baseline" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
CT countT counttruthWin probability
05.05.010.480956
15.05.010.480956
25.05.010.480956
35.05.010.480956
45.05.010.480956
...............
2179210.05.010.246753
2179220.05.010.246753
2179230.05.010.246753
2179240.05.010.246753
2179250.05.010.246753
\n", + "

217926 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " CT count T count truth Win probability\n", + "0 5.0 5.0 1 0.480956\n", + "1 5.0 5.0 1 0.480956\n", + "2 5.0 5.0 1 0.480956\n", + "3 5.0 5.0 1 0.480956\n", + "4 5.0 5.0 1 0.480956\n", + "... ... ... ... ...\n", + "217921 0.0 5.0 1 0.246753\n", + "217922 0.0 5.0 1 0.246753\n", + "217923 0.0 5.0 1 0.246753\n", + "217924 0.0 5.0 1 0.246753\n", + "217925 0.0 5.0 1 0.246753\n", + "\n", + "[217926 rows x 4 columns]" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Flatten the held-out frames to the column names of the lookup table.\n", + "# Guarded so the cell can be re-run: `test` only carries the tuple columns the\n", + "# first time, and selecting them again used to raise KeyError.\n", + "if ('IsAlive', 'CT') in test.columns:\n", + "    test = test[[('IsAlive', 'CT'), ('IsAlive', 'T'), ('Winner', '')]]\n", + "    test.columns = ['CT count', 'T count', 'truth']\n", + "\n", + "# Inner join on purpose: frames whose (CT, T) pair has no row in the training table\n", + "# get no prediction and are dropped (the recorded run kept 217926 of 218333 frames).\n", + "preds = pd.merge(test, player_baseline, on=['CT count', 'T count'], how='inner')\n", + "preds" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7837721634898503" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import roc_auc_score\n", + "\n", + "# Ranking quality of the looked-up win probability on the held-out matches.\n", + "roc_auc_score(preds['truth'], preds['Win probability'])" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6922854546956306" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import accuracy_score\n", + "\n", + "# Accuracy when predicting a CT win whenever the looked-up probability exceeds 0.5.\n", + "accuracy_score(preds['truth'], preds['Win probability'] > 0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "# WOAH. what is that??? that can't be right\n", + "# Compute the ROC curve points of the lookup-table baseline for plotting.\n", + "\n", + "from sklearn.metrics import roc_curve\n", + "\n", + "fpr, tpr, _ = roc_curve(preds['truth'], preds['Win probability'])" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "# ROC curve of the lookup-table baseline against the chance diagonal.\n", + "roc_fig = plt.figure(figsize=(10, 7))\n", + "roc_ax = roc_fig.gca()\n", + "roc_ax.plot(fpr, tpr)\n", + "roc_ax.plot([0, 1], [0, 1], 'k--')\n", + "roc_ax.set_xlabel('FPR')\n", + "roc_ax.set_ylabel('TPR')\n", + "roc_ax.set_xlim([0, 1])\n", + "roc_ax.set_ylim([0, 1])\n", + "roc_ax.set_title('ROC curve of players alive baseline model');" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MatchIdRoundNumTickCT countT countWinner
041152575.05.01
141152905.05.01
241153235.05.01
341153565.05.01
441153895.05.01
\n", + "
" + ], + "text/plain": [ + " MatchId RoundNum Tick CT count T count Winner\n", + "0 4 1 15257 5.0 5.0 1\n", + "1 4 1 15290 5.0 5.0 1\n", + "2 4 1 15323 5.0 5.0 1\n", + "3 4 1 15356 5.0 5.0 1\n", + "4 4 1 15389 5.0 5.0 1" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# examine what is up with the weird numbers for 0 ct's\n", + "\n", + "# Replace the two-level column labels of the training frames with flat names.\n", + "flat_column_names = ['MatchId', 'RoundNum', 'Tick'] + ['CT count', 'T count', 'Winner']\n", + "train.columns = flat_column_names\n", + "train.iloc[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + "
MatchIdRoundNumTickCT countT countWinner
4974647193008350.05.00
4974847193009010.05.00
4974947193009340.05.00
4975047193009670.05.00
4975247193010330.05.00
1654882351114070.05.01
1654892351114730.05.01
1654902351115390.05.01
1654912351116050.05.01
1654922351116710.05.01
28142841210989680.05.00
28143041210992320.05.00
286213417253343490.05.00
286214417253344150.05.00
3377614741197710.05.01
3377624741198370.05.01
3377634741199030.05.01
4335135821192720.05.01
4335145821193380.05.01
4335155821194040.05.01
4335165821194700.05.01
4335175821195360.05.01
4335185821196020.05.01
5913899597990890.05.00
5913929597992870.05.00
5913939597993530.05.00
5913959597994850.05.00
5913969597995510.05.00
72813312711161770.05.01
72813412711162430.05.01
72813512711163090.05.01
72813612711163750.05.01
72813712711164410.05.01
72813812711165070.05.01
80155914191169650.05.01
80156014191170310.05.01
80156114191170970.05.01
80156214191171630.05.01
80156314191172290.05.01
8363871447121592220.05.00
8837131506243838230.05.00
8837141506243838560.05.00
8837151506243838890.05.00
8837171506243839550.05.00
8837181506243839880.05.00
8837211506243840870.05.00
8837221506243841200.05.00
8837231506243841530.05.00
\n", + "
" + ], + "text/plain": [ + " MatchId RoundNum Tick CT count T count Winner\n", + "49746 47 19 300835 0.0 5.0 0\n", + "49748 47 19 300901 0.0 5.0 0\n", + "49749 47 19 300934 0.0 5.0 0\n", + "49750 47 19 300967 0.0 5.0 0\n", + "49752 47 19 301033 0.0 5.0 0\n", + "165488 235 1 11407 0.0 5.0 1\n", + "165489 235 1 11473 0.0 5.0 1\n", + "165490 235 1 11539 0.0 5.0 1\n", + "165491 235 1 11605 0.0 5.0 1\n", + "165492 235 1 11671 0.0 5.0 1\n", + "281428 412 10 98968 0.0 5.0 0\n", + "281430 412 10 99232 0.0 5.0 0\n", + "286213 417 25 334349 0.0 5.0 0\n", + "286214 417 25 334415 0.0 5.0 0\n", + "337761 474 1 19771 0.0 5.0 1\n", + "337762 474 1 19837 0.0 5.0 1\n", + "337763 474 1 19903 0.0 5.0 1\n", + "433513 582 1 19272 0.0 5.0 1\n", + "433514 582 1 19338 0.0 5.0 1\n", + "433515 582 1 19404 0.0 5.0 1\n", + "433516 582 1 19470 0.0 5.0 1\n", + "433517 582 1 19536 0.0 5.0 1\n", + "433518 582 1 19602 0.0 5.0 1\n", + "591389 959 7 99089 0.0 5.0 0\n", + "591392 959 7 99287 0.0 5.0 0\n", + "591393 959 7 99353 0.0 5.0 0\n", + "591395 959 7 99485 0.0 5.0 0\n", + "591396 959 7 99551 0.0 5.0 0\n", + "728133 1271 1 16177 0.0 5.0 1\n", + "728134 1271 1 16243 0.0 5.0 1\n", + "728135 1271 1 16309 0.0 5.0 1\n", + "728136 1271 1 16375 0.0 5.0 1\n", + "728137 1271 1 16441 0.0 5.0 1\n", + "728138 1271 1 16507 0.0 5.0 1\n", + "801559 1419 1 16965 0.0 5.0 1\n", + "801560 1419 1 17031 0.0 5.0 1\n", + "801561 1419 1 17097 0.0 5.0 1\n", + "801562 1419 1 17163 0.0 5.0 1\n", + "801563 1419 1 17229 0.0 5.0 1\n", + "836387 1447 12 159222 0.0 5.0 0\n", + "883713 1506 24 383823 0.0 5.0 0\n", + "883714 1506 24 383856 0.0 5.0 0\n", + "883715 1506 24 383889 0.0 5.0 0\n", + "883717 1506 24 383955 0.0 5.0 0\n", + "883718 1506 24 383988 0.0 5.0 0\n", + "883721 1506 24 384087 0.0 5.0 0\n", + "883722 1506 24 384120 0.0 5.0 0\n", + "883723 1506 24 384153 0.0 5.0 0" + ] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Training frames recorded with no CT alive while all five T players are alive.\n", + "no_ct_alive = train['CT count'] == 0\n", + "all_t_alive = train['T count'] == 5\n", + "train[no_ct_alive & all_t_alive]" + ] + },
+ { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transforming raw data...\n", + "Transforming 1495/1495: 1649, de_dust2, 22\n", + "Done!\n" + ] + } + ], + "source": [ + "import warnings\n", + "from csgo_wp.data_transform import transform_data, CSGODataset\n", + "\n", + "# Suppress every warning for the rest of the session.\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "# No dataset_split is passed, so the constructor's default split is loaded.\n", + "dataset = CSGODataset(transform=transform_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([413319, 12, 10])\n" + ] + } + ], + "source": [ + "# Size of the transformed data tensor of the default split.\n", + "print(dataset.data.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transforming raw data...\n", + "Transforming 642/642: 1786, de_dust2, 27\n", + "Done!\n", + "Transforming raw data...\n", + "Transforming 603/603: 1799, de_dust2, 30\n", + "Done!\n" + ] + } + ], + "source": [ + "# Build the validation split first, then the test split.\n", + "val_dataset, test_dataset = [CSGODataset(transform=transform_data, dataset_split=split_name)\n", + "                             for split_name in ('val', 'test')]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([176370, 12, 10])\n", + "torch.Size([164819, 12, 10])\n" + ] + } + ], + "source": [ + "# Validation size first, test size second.\n", + "for held_out in (val_dataset, test_dataset):\n", + "    print(held_out.data.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5100189291896169\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "\n", + "from sklearn.linear_model import LogisticRegression\n", + "\n", + "# Plain logistic regression on the samples flattened to 120 features each.\n", + "lr = LogisticRegression(random_state=13)\n", + "lr.fit(dataset.data.view(-1, 120), dataset.targets)\n", + "preds = lr.predict_proba(val_dataset.data.view(-1, 120))\n", + "\n", + "from sklearn.metrics import roc_curve, roc_auc_score\n", + "\n", + "# predict_proba columns follow lr.classes_ (sorted labels), so with 0/1 targets\n", + "# column 1 is the score of the positive class. Column 0 was used here before; it\n", + "# ranks the samples in reverse and reports 1 - AUC. Re-run to refresh the outputs.\n", + "fpr, tpr, _ = roc_curve(val_dataset.targets, preds[:, 1])\n", + "\n", + "print(roc_auc_score(val_dataset.targets, preds[:, 1]))\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "# ROC curve against the chance diagonal.\n", + "plt.figure(figsize=(10, 7))\n", + "plt.plot(fpr, tpr)\n", + "plt.plot([0, 1], [0, 1], 'k--')\n", + "plt.xlabel('FPR')\n", + "plt.ylabel('TPR')\n", + "plt.xlim([0, 1])\n", + "plt.ylim([0, 1])\n", + "plt.title('ROC curve of logistic regression baseline model');" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 100 candidates, totalling 500 fits\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 26 tasks | elapsed: 6.1min\n", + "[Parallel(n_jobs=-1)]: Done 176 tasks | elapsed: 33.3min\n", + "[Parallel(n_jobs=-1)]: Done 426 tasks | elapsed: 76.2min\n", + "[Parallel(n_jobs=-1)]: Done 500 out of 500 | elapsed: 88.2min finished\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5017980159110792\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from sklearn.model_selection import RandomizedSearchCV\n", + "from scipy.stats import loguniform, uniform\n", + "\n", + "# Search space: C drawn log-uniformly from [1e-3, 1e1], l1_ratio from scipy's default uniform().\n", + "params = dict(C=loguniform(1e-3, 1e1), l1_ratio=uniform())\n", + "\n", + "lr = LogisticRegression(random_state=13, solver='saga', penalty='elasticnet')\n", + "\n", + "# 100 sampled candidates, scored by ROC AUC, on all available cores.\n", + "cv = RandomizedSearchCV(lr, params, random_state=3, n_iter=100, n_jobs=-1, verbose=1, scoring='roc_auc')\n", + "cv.fit(dataset.data.view(-1, 120), dataset.targets)\n", + "\n", + "# Evaluate the best candidate on the validation split; column 1 is the positive-class score.\n", + "preds = cv.best_estimator_.predict_proba(val_dataset.data.view(-1, 120))\n", + "val_scores = preds[:, 1]\n", + "\n", + "fpr, tpr, _ = roc_curve(val_dataset.targets, val_scores)\n", + "\n", + "print(roc_auc_score(val_dataset.targets, val_scores))\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "roc_fig = plt.figure(figsize=(10, 7))\n", + "roc_ax = roc_fig.gca()\n", + "roc_ax.plot(fpr, tpr)\n", + "roc_ax.plot([0, 1], [0, 1], 'k--')\n", + "roc_ax.set_xlabel('FPR')\n", + "roc_ax.set_ylabel('TPR')\n", + "roc_ax.set_xlim([0, 1])\n", + "roc_ax.set_ylim([0, 1])\n", + "roc_ax.set_title('ROC curve of logistic regression baseline model');" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LogisticRegression(C=0.02305440604242659, class_weight=None, dual=False,\n", + " fit_intercept=True, intercept_scaling=1,\n", + " l1_ratio=0.06467319799028326, max_iter=100,\n", + " multi_class='auto', n_jobs=None, penalty='elasticnet',\n", + " random_state=13, solver='saga', tol=0.0001, verbose=0,\n", + " warm_start=False)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cv.best_estimator_" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 100 candidates, totalling 500 fits\n" + ] + }, + { + "name": "stderr", +
"output_type": "stream", + "text": [ + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 38 tasks | elapsed: 1.7min\n", + "[Parallel(n_jobs=-1)]: Done 188 tasks | elapsed: 9.2min\n", + "[Parallel(n_jobs=-1)]: Done 438 tasks | elapsed: 22.1min\n", + "[Parallel(n_jobs=-1)]: Done 500 out of 500 | elapsed: 25.1min finished\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5018716279161413\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from sklearn.model_selection import RandomizedSearchCV\n", + "from scipy.stats import loguniform, uniform\n", + "from sklearn.feature_selection import VarianceThreshold\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "# Search space: C drawn log-uniformly from [1e-3, 1e1], l1_ratio from scipy's default uniform().\n", + "params = dict(C=loguniform(1e-3, 1e1), l1_ratio=uniform())\n", + "\n", + "# Flatten each sample to 120 features once instead of repeating the view() call.\n", + "train_features = dataset.data.view(-1, 120)\n", + "val_features = val_dataset.data.view(-1, 120)\n", + "\n", + "thresh = VarianceThreshold() # remove constant features\n", + "scaler = StandardScaler()\n", + "\n", + "# Both preprocessing steps are fitted on the training features only.\n", + "thresh.fit(train_features)\n", + "scaler.fit(thresh.transform(train_features))\n", + "\n", + "lr = LogisticRegression(random_state=13, solver='saga', penalty='elasticnet')\n", + "\n", + "cv = RandomizedSearchCV(lr, params, random_state=3, n_iter=100, n_jobs=-1, verbose=1, scoring='roc_auc', pre_dispatch='n_jobs')\n", + "train_inputs = scaler.transform(thresh.transform(train_features))\n", + "cv.fit(train_inputs, dataset.targets)\n", + "\n", + "# Apply the same fitted preprocessing to the validation features before scoring.\n", + "val_inputs = scaler.transform(thresh.transform(val_features))\n", + "preds = cv.best_estimator_.predict_proba(val_inputs)\n", + "val_scores = preds[:, 1]\n", + "\n", + "fpr, tpr, _ = roc_curve(val_dataset.targets, val_scores)\n", + "\n", + "print(roc_auc_score(val_dataset.targets, val_scores))\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "roc_fig = plt.figure(figsize=(10, 7))\n", + "roc_ax = roc_fig.gca()\n", + "roc_ax.plot(fpr, tpr)\n", + "roc_ax.plot([0, 1], [0, 1], 'k--')\n", + "roc_ax.set_xlabel('FPR')\n", + "roc_ax.set_ylabel('TPR')\n", + "roc_ax.set_xlim([0, 1])\n", + "roc_ax.set_ylim([0, 1])\n", + "roc_ax.set_title('ROC curve of logistic regression baseline model');" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5511878437375971" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Mean accuracy of the best estimator on the preprocessed validation features.\n", + "scaled_val_features = scaler.transform(thresh.transform(val_dataset.data.view(-1, 120)))\n", + "cv.best_estimator_.score(scaled_val_features, val_dataset.targets)" + ] + }, + { + "cell_type": "code", +
"execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEHCAYAAABvHnsJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de5wcZZ3v8c9vei7MDAMZkyFiJpDAgWhWIZoBwXhclYUFRDCCKyhEOWeBbERx96jR4xE567oaOB7RxWwQX4goBxQwiHhBl4V1Ny7KBAG5GA0gZkDCEAcSJpPpzPTv/FHVk55JX6o7XX39vl+vvDLdXZdfP1Vdv6rneeopc3dERERK0VLtAEREpH4piYiISMmUREREpGRKIiIiUjIlERERKVlrtQOohDlz5viCBQuqHYaISF3ZuHHj8+7el2+apkgiCxYsYHBwsNphiIjUFTN7qtA0qs4SEZGSKYmIiEjJlERERKRkSiIiIlIyJRERESlZU/TOiksq5WwbTZKcmKS9NcHs7nZaWizn+6UuE+D50XF27Z4kYUZne4JZndGXGXW9L4wlGUtOMunOfm0J5nR3lLyOQmWwL2VUzHpKlbnczvYEEyln90Rqn9dRTLyZ07a1ttDaYowly/s9yy2u7bGv65k5fW9nGyNju4uKM3MZZkbCoKWlpaTvWKlyqgQlkSKkUj51MG81Y+fuSd7/9fsYGhmjv7eTa1YMcPCBHTw9sosv3fVbzlw6n9nd7YwlJ3jFgZ20tua/8EulnE1bd3DB9YNTy7zu/GNITqS48l/2LO9l3e2Mjk8wb1ZXUT+cXDtqKuX8ftsoW7fv4qO3PDTt+yya27PXPLkS3bbRJKlUCjNjeMc4F31r47RlHdG3/9QPdzLl/MMPHuUnjz6Xd135vlcqleL5l5J7rSfqcvItP70d+vbv4GMnL4pcLi+MJZmYSJFMOZMppz3RQt/+HbS2tmTdvvmWNXPaK846ist/vInhl8a5ZsUAcw/oiJRUyn3AynfylOv7AWU7aRgZG+ePL4yzMuJ2nxnXSYsP4kMnHBl5/sxl3Hb/Fs4aOIREi9GWaOGuXw/x+sP7itrn0r+3p7btpKs9wc7kJIfO7mLB7O6ynwDta7KLwpphKPiBgQEv5T6RmWejW18c54Jv7v2j/tWWFwDCg/6xXP7jx3jfGxay+tY9B56rz1vKq15+QN6NOLxjnOVrNzA0Mjb13tfffww3/vKpvZa39r2vo7+3k5d1d+SMPeoBa3jHOA8//SKf+t7D09bd39vJ+lXL6OvpKLjcjtYWPv+j4HsnJ1J7LeukxQdxyV8cyUXf3PPDXXPmUfyfO4Pyy7auQt/rU6ct5jN3PFow5mJlboerz1saaR3pA8NYcoLtuyamJZ115y7llXN7GBnbvdf2zRVvtn2hv7eTT522mIu+uZH+3k4+c8arOf+6+4pORvuSaPMtb9toMmvMt1+8jK3bx/c5hvS6n31xV6R9NW1mWUbdpjOX8bWfbeZtR89j1Q33T32Pf37vUu54cIi/ftN/ibzP/Wl0nE3P7pi2j1xx1lEsenkPszrb9ynZZts+a848im/8/En+9sRFRZe5mW1094F806hNJIf0xli+dgPL1tzNg1tenEogAEMjY3z0lodY+ebDp+YZGhmjxeDMpfOnDvjp9y/65ka2jSbzrjM5MTltxwboak9kXd6qG+5nLDmZc1nbRpNTO1J6nguuH8waQ3Jikq72xF7rHhoZIzkxfR25lvvUtp1TcWZb1plL508lkPR8q2/dU37Z1lXoe83qbIsUc7Eyt0PUdWwbTfLUtp08tyM5dXBIT7vyWxt57qXxrNs3V7y5pp3V2Tb1d1d7YurvXNu2mP0ginzLyxXzWHIy5z
yplDO8Y5ynR3YyvGOcVCr3SW163VH3VQh+x2O7J6ZNX8p+k5yY5KyBQ6YSSHqev7lhI2cNHFLUPjeWnNxrH/noLQ+xKzk57ZizfO0GNm3dkbdMZsq2fVbf+hBnLp2/T9s9HyWRHGZujFw7bvpHDcHZzGTKmd3dXtLBrb01QX9v57T3diYncy5vMs/OVcwBq701uKSeue7+3k7aWxORltvVnpj6cb4wtnuvZeX6Dunyy7auQt8r23qiLiefzO2Qax1mNu3Al07EufaTiclU1u2bK95c074wtnuvv9PrKCYZlZpo8y0vV8yT7jnnKeagmV531O2ePhF8/LnRadOXst+0tyZoTVjW75FosaL2uVzlMZHyfU74+U4+ynGClY2SSA4zN0auHW9neDWQviS95mdP8LLu9pIObrO727lmxcDUvP29nfR2tzF7/+zLa81zWVrMAWt2dzuHzu7iirOOmrbua1YMTLV3FFruzuTkVBmtu+dx1pw5fVkH9XTkPCjmWleh75VtPVeftzTScvLJ3A7r7nl8r3JZd+5SLrv94WkHvs6wbjtXMm5NtGTdvrm+d7ZprzjrKNbd8/hUdea6ex6fto5iklGpiTbf8nJ9v/3ass9jZkUdNNPrjrrd0yeCX77rd9Omv3XjFtaduzTSdkjr7WyjLdGS47u3FLXP5SqPREv2JFXMgT/fyUc5TrCyUZtIDjPrUV87fxZ/f8af8TcZ9aFfeNfRpNyZN6uT/doTvLRrghXX/pI3HDabc48/dFrd6dXnLuVVB+dvE4E9jfc7xyd58vlRvnzX7zh2wSzevqR/WkPgF//qaA4/aP+ytImkG4R3T6TYnXJSeXpnRWkTWX3rQ/Tt38GHTjiChXO66epI8LLOdn43/NK0+a4+bylzutuLavRLpZzH/rh9qjH9pMUH8fFTXsWOXRPM6mpjXoQODFHk6p1lZlx2+8P85NHnpqbt7+3ku6vewEu7JvK2iaQb14vpnZXeF57dvgt3py3RwkE9HZjBOdf8ItK2rVSbSK6eiUDWeV7W1cbrP/eve61jw+q3MK+3K++60/vXobO7aEu0cPAB++213Z8e2cmyNXcDwe935ZsPZ1ZnG/29nczt2a+o3lk520TOXcorD9qftrboB+dcZTj3gA5Ovypam1kxy467TURJJIdsG+Pmi47jN8++RFd7ghfGdrPunscZfml8aiOnUs6z23fxzAtj7J5MMZkKfvQ7k5McPf/AnAf8XOvP7N45vnuSzc+NFtWbI8oBq5SDTKHeWZMO7r7XOsvVS+hPo+M8uOVFZnW1sX9HK7t2T/L8S8miy7gUmQemTBtWv4WDD+yc1jsrFW7/dO+sUuUr72K7CsfZO6vYeXI1xOc7aO7pIZkiYeTt7p6rc0IpHS/S2/2vlvZzwZsOI9FiTKac3q425vTsV9Sy0t8jarItpQNCuXpnKYmEytE7K923fObZ9MyNXO4zv1yxlKu7Xjl/aJUSVxlHkau8vnPR8VkTp+QW93Ys5/Ir9TuptftHlERCpSaRbKKe3dfSjpBPvjPrbFUKtaJaZZztwLTu3KV8+a7flnTPS7OLezuW86bWYqvxGmH7K4mEyplEGk09XolU28zqgmxtJCq/xlPKTZb1nkh0n4gU1NvZxtXnFddTpdm1tBh9PR3M6+3C3aclECjPvSpSezK3e1/Pnk4n5b4Xp95o2JMmlko5vxt+iS/9y2/51GmLmd3dzkE9HbziwM66P4OqlHSXyplXcnF0pZTaVO57ceqNkkgTyzyDSp9NqyqmOOl7I2ZWZehKrnk0+4mEkkgTa/YzqHJoaTEWze1h/aplDdeoKtE0+4mEkkgTa/YzqHJJ15VLc2r2Ewk1rDexYobhECmkmMEUG02uRvdmoCuRJtbsZ1BRNeo9AOXUyN1cJT8lkSanqpj8SUIHx2hydXNVJ43Gp+osaWoznxszczjyeroHoJrVSeqk0byURBpArdZF12pcmQoliXo5OBZKhnEr95
Dzsrda/T0pidS5ah886i2umQoliXo5OFb7ikmdNOJVy78nJZE6V+2DR73FNVOhJFEvB8dqXzFldtLYsPotrF+1TO1GZVTLvyc1rMeoEr16qn3wyKVW45qp0I1i9dKDrRbu+VEnjfjU8u9JSSQmlerVUwsHj2xqNa6ZoiSJejg4Nvtd042uln9PGgo+JpV8iE0tdkGt1bgame5naVzV+j3peSKhaiSRSj7sqdDBo5oPcKq3g1o9xizNoRr7ZpQkouqsmFTy8jNfdUs1rwjqoRooU6XLSglLilGrvyf1zopJrfTqqeVeHbUmSlmVq69+LXfZFCmGrkQiKOWMsVZ69dRyr45aU6isynmlomFCpFHoSqSAfTljrIWRPUu9Wa5W746NU6GyKudVnZK7NAolkQLqvTqolGq1ZqlqmZkoezvb8pZVOQ/89XInvNSuWjnRU3VWAfV+xlhKtVozVLXkqpo6om//nGVVzs4Suq9D9kUtdaFXEimgvTXBSYsP4syl85nV2cYLY7u5deOWujpjLLZXR70nzihKSZTlPPDXSpuZ1KdaOtFTEimgt7OND51wJCu/tXHqwLHu3KX0drZVO7TY1PLdseVSSqIs94G/VrtsSu2rpRO9WNtEzOxkM9tkZpvN7ON5pjvGzCbN7Kzw9SIzeyDj33Yz+/CMeT5iZm5mc+L8DiNju6cSCAQbauW3NjIytjvO1VZVrXRPjlOpbRK10FlCpJba1GK7EjGzBPAV4ERgCLjPzG5390ezTLcGuDP9nrtvApZkfP40sD5jnvnhcv8QV/xptZTxK6UZqlrUJiH1rJb23zirs44FNrv7EwBmdhNwBvDojOk+CNwKHJNjOScAj7v7UxnvfRH4GPC9skacRTNU7WTT6FUtzZAopXHV0v4bZ3XWPGBLxuuh8L0pZjYPWA6sy7Ocs4EbM+Y5HXja3R/Mt3Izu9DMBs1scHh4uNjYpzRD1U6zUtWU1LNa2X/jvBLJ9o1mdmS+Eljt7pNme09uZu3A6cAnwtddwCeBkwqt3N2/CnwVggEYi4o8Qy1lfJF6prHCGlOcSWQImJ/xuh94ZsY0A8BNYQKZA5xqZhPuflv4+SnA/e6+NXx9OLAQeDCcpx+438yOdfdn4/kajV+1I1KMUpJBLd3XIOUVZxK5DzjCzBYSNIyfDbwncwJ3X5j+28yuA+7ISCAA55BRleXuvwYOypjn98CAuz8fQ/wiMkOpyaCW7muQ8oqtTcTdJ4CLCXpdPQZ8x90fMbOVZray0Pxh1dWJwHfjilFEilPqMEDN2MuxWcR6s6G7/xD44Yz3sjaiu/v7Z7zeCcwusPwF+xahiBSj1GTQrL0cm4EGYJSi1Mqgb1Idpd7kpl6OjUuPx5XI1Dgq+7IPqHdW/dEz1kNKIuUxvGOc5Ws37FUlocbR5tJsyaDZvm8mPWNdykqNowLN1eVdV9+FqU1EIqulQd9EKqHeH0pXCUoiEpkaR6XZ6Oq7MFVnSWQaAkaajbomF6YrESlKrQz6JlIJuvouTFciIiI56Oq7MCWRJtLMXRVFStVMvdFKoSTSJNRVUUTioDaRJqGuiiISByWRJqGuiiISByWRJqEbBUUkDkoiTUJdFUUkDmpYbxLqqigicVASaSLqqigi5abqLBERKZmSiIiIlKxgErHAuWZ2afj6EDM7Nv7QRESk1kW5ElkLHA+cE77eAXwltohERKRuRGlYf727v87MfgXg7iNmpn6hIiIS6Upkt5klAAcwsz4gFWtUIiJSF6IkkS8D64GDzOyzwH8A/xhrVCIiUhfyVmeZWQvwJPAx4ATAgHe4+2MViE1ERGpc3iTi7ikz+4K7Hw/8pkIxiYhInYhSnfUTMzvTzDQ+hoiITBOld9bfAd3ApJntCt9zdz8gvrBERKQeFEwi7t5TiUBERKT+RBqA0cxOB94UvrzH3e+ILyQREakXUYY9+TxwCfBo+O+S8D0REWlyUa5ETgWWuHsKwMy+AfwK+HicgUl5pF
LOttGkniEiIrGI+jyRWcCfwr8PjCkWKbNUytm0dQcXXD/I0MjY1NMMF83tUSIRkbKI0sX3c8CvzOy68CpkI7pjvS5sG01OJRCAoZExLrh+kG2jySpHJiKNIkrvrBvN7B7gGII71le7+7NxByb7LjkxOZVA0oZGxkhOTFYpIhFpNFEa1pcDO939dnf/HrDLzN4Rf2iyr9pbE/T3dk57r7+3k/bWRJUiEpFGE6U669Pu/mL6hbu/AHw6vpCkXGZ3t3PNioGpRJJuE5ndrZH8pfalUs7wjnGeHtnJ8I5xUimvdkiSRZSG9WyJJur9JScDXwISwNfcPWvXYDM7BrgXeLe732Jmi4BvZ0xyGHCpu19pZp8BziAYjv454P3u/kyUeJpNS4uxaG4P61ctU+8sqSvqFFI/olyJDJrZ/zWzw83sMDP7IkHjel7hM0i+ApwCLAbOMbPFOaZbA9yZfs/dN7n7EndfAiwFdhIMRw9whbsfFX52B3BphO/QtFpajL6eDub1dtHX06EfoNQFdQqpH1GSyAeBJMGVwc3ALuADEeY7Ftjs7k+4exK4ieAKItvybyW4qsjmBOBxd38KwN23Z3zWTfiwLBFpHOoUUj+i9M4aJbyxMLxq6A7fK2QesCXj9RDw+swJzGwesBx4K0Hvr2zOBm6cMd9ngRXAi8BbIsQiInUk3SkkM5FUq1OIbtjNL0rvrP9nZgeYWTfwCLDJzD4aYdnZSnnmVcOVBF2Gs55ehM9yP53gCmjPQtw/6e7zgRuAi3PMe6GZDZrZ4PDwcIRwRaRW1EqnkHTbzPK1G1i25m6Wr93Apq071MifwdzzF4aZPeDuS8zsvQTtE6uBje5+VIH5jgcuc/e/DF9/AsDdP5cxzZPsSTZzCNo+LnT328LPzwA+4O4n5VjHocAP3P3V+WIZGBjwwcHBvN9TRGpLLVwBDO8YZ/naDXtdEa1ftYy+no6KxlINZrbR3QfyTROll1WbmbUB7wCucvfdZhYlDd8HHGFmC4GnCaql3pM5gbsvzAj2OuCOdAIJncPeVVlHuPvvwpenoycuijSkdKeQalLbTGFRksjVwO+BB4GfhWf/2/POAbj7hJldTNDrKgFc6+6PmNnK8PN1+eY3sy7gROCiGR99PuwCnAKeAlZG+A4iIkWrpbaZWlWwOmuvGYLH5CbcfSJ8/T53/0YcwZWLqrNEpBTNfr9KuaqzpvEg60xkvHUJUNNJRESkFLpht7Cik0gWKk0RaVi10DZTy8qRRNTXTUTKqhZ6Zkk0uhIRkZrS7O0Q9SbKzYYLC7y3oawRiUhT07hZ9SXK2Fm3ZnnvlvQf7p71jnERkVLo3oz6krM6y8xeCfwZcKCZvTPjowOA/eIOTESak+7NqC/5rkQWAacBs4C3Z/x7HXBB/KFJNemBQFIttTJulkQTZeys4939PysUTyx0s2Fx1LAp1abeWbUhys2GUdpEloej+LaZ2V1m9ryZnVumGKUGqWFTqk0PU6sfUZLISeGDoE4jeCbIkUCUoeClTqlhU0SiipJE2sL/TwVudPc/xRiP1IB0w2YmNWyKSDZRksj3zew3wABwl5n1ETwiVxqUGjZFJKpIo/iaWS+w3d0nwycc9rj7s7FHVyZqWC+eGjZFpCyj+IbP9fgAcAhwIfAKgu6/d5QjSKlNGnRORKKIUp31dSAJvCF8PQT8Q2wRiYhI3YiSRA5398uB3QDuPoYGXRQREaIlkaSZdRIO+W5mhwPjsUYlIiJ1IcpQ8JcBPwbmm9kNwDLg/DiDEhGR+lAwibj7T8xsI3AcQTXWJe7+fOyRiYhIzYvyPJG73H2bu//A3e9w9+fN7K5KBCciIrUt31Dw+wFdwJzwPpF0Y/oBBN18RUSkyeWrzroI+DBBwtjIniSyHfhKzHGJiEgdyJlE3P1LwJfM7IPu/k+5pjOzE939p7FEJyIiNa1gm0i+BBJaU6ZYRESkzkS5T6QQ3XgoItKkypFE9NxUEZEmVY4kIiIiTa
ocSeT3ZViGiIjUoSjDnmBmbwAWZE7v7teH/78zlshERKTmRXmeyDeBw4EHgPRDth24Psa4RESkDkS5EhkAFnuURyCKiEhTiZJEHgZeDvwx5lhEJAI9ulhqSZQkMgd41Mx+ScZzRNz99NiiEpGsUiln09YdXHD9IEMjY/T3dnLNigEWze1RIpGqiPo8ERGpAdtGk1MJBGBoZIwLrh9k/apl9PV0VDk6aUZRnifyb5UIREQKS05MTiWQtKGRMZITkznmkEZWC1WbUZ4ncpyZ3WdmL5lZ0swmzWx7JYITkenaWxP093ZOe6+/t5P21kSVIpJqSVdtLl+7gWVr7mb52g1s2rqDVKqyfaCi3Gx4FXAO8DugE/jr8D0RqbDZ3e1cs2JgKpGk20Rmd7dXOTKptFxVm9tGkxWNI9LNhu6+2cwS7j4JfN3Mfh5lPjM7GfgSkAC+5u6fzzHdMcC9wLvd/RYzWwR8O2OSw4BL3f1KM7sCeDuQBB4Hznf3F6LEI1LvWlqMRXN7WL9qmXpnNblaqdqMciWy08zagQfM7HIz+1ugu9BMZpYgeHjVKcBi4BwzW5xjujXAnen33H2Tuy9x9yXAUmAnsD78+KfAq939KOC3wCcifAeRhtHSYvT1dDCvt4u+ng4lkCZVK1WbUZLIeeF0FwOjwHzgzAjzHQtsdvcn3D0J3ASckWW6DwK3As/lWM4JwOPu/hSAu//E3SfCz+4F+iPEIiLSUGqlajNK76ynzKwTONjd/3cRy54HbMl4PQS8PnMCM5sHLAfeChyTYzlnAzfm+Oy/Mb3aK3PZFwIXAhxyyCGRgxYRqQe1UrUZpXfW2wnGzfpx+HqJmd0eYdnZvsnMbgNXAqvDtpZs624HTgduzvLZJ4EJ4IZs87r7V919wN0H+vr6IoQrIlJfaqFqM+rNhscC9wC4+wNmtiDCfEMEVV9p/cAzM6YZAG4yMwjujD/VzCbc/bbw81OA+919a+ZMZvY+4DTgBI3pJSJSPVGSyIS7vxge6ItxH3CEmS0EniaolnpP5gTuvjD9t5ldB9yRkUAg6Fo8rSor7PG1Gvhzd99ZbFAiIlI+kQZgNLP3AAkzOwL4EFCwi6+7T5jZxQS9rhLAte7+iJmtDD9fl29+M+sCTgQumvHRVUAH8NMwsd3r7isjfA8RESkzK1QbFB7MPwmcFL51J/AZdx/PPVdtGRgY8MHBwWqHISJSV8xso7sP5JsmShffxeG/VmA/gm669+17eCIiUu+iVGfdAHyE4LkiqXjDEWkstTBAnkicoiSRYXf/fuyRiDQYPftDmkGU6qxPm9nXzOwcM3tn+l/skYnUuVoZIE8kTlGuRM4HXgm0sac6y4HvxhWUSCOolQHyROIUJYkc7e6viT0SkQaTHiAvM5Ho2R/SaKJUZ92bbfRdEcmvVgbIE4lTlCuRNwLvM7MngXGCMbE8HIpdRHKolQHyROIUJYmcHHsUIg0qPUCeSKOKNBR8JQIREZH6E6VNREREJCslERERKVmUNhGRpqVhS0TyUxIRyUHDlogUpuoskRw0bIlIYUoiIjlo2BKRwpRERHJID1uSScOWiEynJCKSg4YtESlMDesiOWjYEpHClERE8tCwJSL5qTpLRERKpiQiIiIlUxIREZGSKYmIiEjJlERERKRkSiIiIlIyJRERESmZkoiIiJRMSUREREqmJCIiIiVTEhERkZJp7CyRIumRuSJ7KImIFEGPzBWZTtVZIkXQI3NFplMSESmCHpkrMp2SiEgR9MhckemURESKoEfmikynhnWRIuiRuSLTxXolYmYnm9kmM9tsZh/PM90xZjZpZmeFrxeZ2QMZ/7ab2YfDz95lZo+YWcrMBuKMXySb9CNz5/V20dfToQQiTS22KxEzSwBfAU4EhoD7zOx2d380y3RrgDvT77n7JmBJxudPA+vDjx8G3glcHVfsIiISTZxXIscCm939CXdPAjcBZ2SZ7oPArcBzOZZzAvC4uz8F4O6PhUlGRE
SqLM4kMg/YkvF6KHxvipnNA5YD6/Is52zgxmJXbmYXmtmgmQ0ODw8XO7uIiEQQZxLJVlHsM15fCax296yd7M2sHTgduLnYlbv7V919wN0H+vr6ip1dREQiiLN31hAwP+N1P/DMjGkGgJvMDGAOcKqZTbj7beHnpwD3u/vWGOMUEZESxZlE7gOOMLOFBA3jZwPvyZzA3Rem/zaz64A7MhIIwDmUUJUlIiKVEVt1lrtPABcT9Lp6DPiOuz9iZivNbGWh+c2si6Bn13dnvL/czIaA44EfmNmd2eYXEZH4mfvMZorGMzAw4IODg9UOQ0SkrpjZRnfPez+ehj0REZGSKYmIiEjJlERERKRkSiIiIlIyJRERESmZkoiIiJRMSUREREqmh1KJ1KlUytk2mtTDsaSqlERE6lAq5WzauoMLrh9kaGRs6jG9i+b2KJFIRak6S6QObRtNTiUQgKGRMS64fpBto8kqRybNRklEpA4lJyanEkja0MgYyYmsT1UQiY2SiEgdam9N0N/bOe29/t5O2lsTVYpImpWSiEgdmt3dzjUrBqYSSbpNZHZ3e5Ujk2ajhnWROtTSYiya28P6VcvUO0uA6vXWUxIRqVMtLUZfT0e1w5AixXGwr2ZvPVVniYhUSPpgv3ztBpatuZvlazewaesOUql9e65TNXvrKYmIiFRIXAf7avbWUxIREamQuA721eytpyQiIlIhcR3sq9lbT89YFxGpkDgbwONosI/yjHX1zhIRqZA4u2ZXq7eekoiISAU1WtdstYmIiEjJlERERKRkSiIiIlIyJRERESmZkoiIiJSsKe4TMbNh4Klqx1Flc4Dnqx1EDVA5BFQOAZVDIFc5HOrufflmbIokImBmg4VuGmoGKoeAyiGgcgjsSzmoOktEREqmJCIiIiVTEmkeX612ADVC5RBQOQRUDoGSy0FtIiIiUjJdiYiISMmUREREpGRKIg3GzE42s01mttnMPp7l8/ea2UPhv5+b2dHViDNuhcohY7pjzGzSzM6qZHyVEqUczOzNZvaAmT1iZv9W6RgrIcLv4kAz+76ZPRiWw/nViDNOZnatmT1nZg/n+NzM7MthGT1kZq+LtGB3178G+QckgMeBw4B24EFg8Yxp3gD0hn+fAvyi2nFXoxwypvtX4IfAWdWOu0r7wyzgUeCQ8PVB1Y67SuXwP4E14d99wJ+A9mrHXuZyeBPwOuDhHJ+fCtLUXkQAAAUqSURBVPwIMOC4qMcGXYk0lmOBze7+hLsngZuAMzIncPefu/tI+PJeoL/CMVZCwXIIfRC4FXiuksFVUJRyeA/wXXf/A4C7N2JZRCkHB3rMzID9CZLIRGXDjJe7/4zge+VyBnC9B+4FZpnZwYWWqyTSWOYBWzJeD4Xv5fLfCc48Gk3BcjCzecByYF0F46q0KPvDkUCvmd1jZhvNbEXFoqucKOVwFfAq4Bng18Al7p6qTHg1o9jjB6AnGzaabM/YzNqH28zeQpBE3hhrRNURpRyuBFa7+2Rw8tmQopRDK7AUOAHoBP7TzO5199/GHVwFRSmHvwQeAN4KHA781Mz+3d23xx1cDYl8/MikJNJYhoD5Ga/7Cc6spjGzo4CvAae4+7YKxVZJUcphALgpTCBzgFPNbMLdb6tMiBURpRyGgOfdfRQYNbOfAUcDjZREopTD+cDnPWgc2GxmTwKvBH5ZmRBrQqTjx0yqzmos9wFHmNlCM2sHzgZuz5zAzA4Bvguc12Bnm5kKloO7L3T3Be6+ALgFWNVgCQQilAPwPeC/mlmrmXUBrwceq3CccYtSDn8guBrDzOYCi4AnKhpl9d0OrAh7aR0HvOjufyw0k65EGoi7T5jZxcCdBD1SrnX3R8xsZfj5OuBSYDawNjwLn/AGG8U0Yjk0vCjl4O6PmdmPgYeAFPA1d8/aBbReRdwfPgNcZ2a/JqjWWe3uDTVEvJndCLwZmGNmQ8CngTaYKoMfEvTQ2gzsJLg6K7zcsGuXiIhI0VSdJSIiJV
MSERGRkimJiIhIyZRERESkZEoiIiJSMiUREREpmZKISJWF41YNhH9/1sy2mNlLZVz+EjM7NeP16fmGxxcphpKISERmVombc79PMOpsUQrEtoTgJjIA3P12d/98CbGJ7EV3rEtTMbMFwI+BXwCvJRgjagXwEeDtBIMQ/hy4yN3dzO4JXy8Dbjez3wL/i+C5FNuA97r7VjO7DFgIHEwwMu7fETyT4RTgaeDt7r67UHzhENxEGRTSzK4jGNr7tcD9ZvZtgoElO4ExgjuOnwT+Hug0szcCnws/H3D3i83sUOBagmdoDAPnp4eFF4lCVyLSjBYBX3X3o4DtwCrgKnc/xt1fTXCQPS1j+lnu/ufu/gXgP4Dj3P21BM+l+FjGdIcDbyN4LsO3gLvd/TUEB/S3xfRdjgT+wt3/B/Ab4E1hbJcC/xg+P+NS4NvuvsTdvz1j/qsIniFxFHAD8OWY4pQGpSsRaUZb3H1D+Pe3gA8BT5rZx4Au4GXAIwRVSwCZB95+4Nvhw3raCc70037k7rvD8ZcSBFc8EDyfYkEcXwS42d0nw78PBL5hZkcQDOHdFmH+44F3hn9/E7i8/CFKI9OViDSjmQPGObCW4BG5rwGuAfbL+Hw04+9/IrhqeQ1w0YzpxgHChxnt9j0D06WI74QtM7bPEFz9vJqgam6/7LPkpcH0pChKItKMDjGz48O/zyGoogJ43sz2B87KM++BBG0cAO+LKb5SZcb2/oz3dwA9Oeb5OcHQ6ADvZU9ZiESiJCLN6DHgfWb2EEHV1T8TXH38GriN4PkTuVwG3Gxm/w6UfahwM7s8HKa7y8yGwgb7qC4HPmdmGwiq09LuBhab2QNm9u4Z83wIOD8si/OAS/YhfGlCGgpemkrYO+uOsMpHRPaRrkRERKRkuhIRqRAzW09wL0mm1e5+Z4H5Pgm8a8bbN7v7Z8sZn0gplERERKRkqs4SEZGSKYmIiEjJlERERKRkSiIiIlKy/w/+15qV9k/f0AAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import pandas as pd\n", + "import seaborn as sns\n", + "\n", + "results = pd.DataFrame.from_dict(cv.cv_results_)\n", + "sns.scatterplot(x='param_C', y='mean_test_score', data=results)\n", + "plt.xscale('log')\n", + "plt.show()\n", + "sns.scatterplot(x='param_l1_ratio', y='mean_test_score', data=results)\n", + "plt.show()\n", + "sns.scatterplot(x='param_C', y='mean_test_score', hue='param_l1_ratio', data=results)\n", + "plt.xscale('log')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notes/10_24_ new data, baseline.pdf b/notes/10_24_ new data, baseline.pdf new file mode 100644 index 0000000..3317689 Binary files /dev/null and b/notes/10_24_ new data, baseline.pdf differ