Skip to content

Commit

Permalink
re-finished dataset keeping the data in RAM
Browse files Browse the repository at this point in the history
  • Loading branch information
guidopetri committed Oct 25, 2020
1 parent 7765fd3 commit 2c38adf
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions csgo_wp/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,18 @@ def __init__(self,

print(f'Found {bad_round_count} rounds with fewer than 10 players')

self.transform = transform

for k, v in splits.items():
with open(self.folder + k + f'/{k}.pckl', 'wb') as f:
with open(folder + k + f'/{k}.pckl', 'wb') as f:
pickle.dump(v, f)

self.raw_data = splits[self.split]
del splits
else:
with open(f'{self.folder}{self.split}/{self.split}.pckl',
with open(f'{folder}{self.split}/{self.split}.pckl',
'rb') as f:
self.raw_data = pickle.load(f)

self.transform = transform
self.data = []
self.targets = []

Expand All @@ -229,24 +228,28 @@ def __init__(self,

print('Transforming raw data...')

for game_round in self.raw_data:
transformed = self.transform(game_round, 'de_dust2')
self.data.extend(transformed)
len_data = len(self.raw_data)

for idx, game_round in enumerate(self.raw_data):
match_id = game_round['MatchId'].values[0]
map_name = game_round['MapName'].values[0]
round_num = game_round['RoundNum'].values[0]
print(f'\rTransforming {idx +1}/{len_data}: {match_id}, '
f'{map_name}, {round_num}', end='')
transformed = self.transform(game_round, 'de_dust2')
self.data.extend(transformed)

target = self.rounds[(self.rounds['MatchId'] == match_id)
& (self.rounds['MapName'] == map_name)
& (self.rounds['RoundNum'] == round_num)]
target = 1 if target['WinningSide'] == 'CT' else 0
self.targets.extend([target for _ in range(transformed.shape[0])])
target = 1 if target['WinningSide'].iloc[0] == 'CT' else 0
self.targets.extend([target
for _ in range(transformed.shape[0])])

self.data = torch.Tensor(self.data)
self.data = torch.stack(self.data)
self.targets = torch.Tensor(self.targets)

print('Done!')
print('\nDone!')

def __len__(self):
return self.data.shape[0]
Expand Down

0 comments on commit 2c38adf

Please sign in to comment.