-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathutils.py
90 lines (72 loc) · 2.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import numpy as np
import glob
# Length of each per-cluster label vector allocated in get_gan_data; every
# class index read from a gan<i>.list file must be smaller than this value.
n_classes = 316
import json
def get_gan_data(generated_size, n_clusters=3, generated_dir=None,
                 labels_dir='/home/paul/clustering', num_classes=None):
    """Pick GAN-generated images and pair each with its cluster's label vector.

    For every cluster ``i`` the file ``gan<i>.list`` in ``labels_dir`` is read;
    each non-blank line holds one integer class index. The indices are turned
    into a multi-hot vector that is normalised to sum to 1. Image files
    matching ``gan_<i>*.jpg`` in ``generated_dir`` are then collected, shuffled
    with ``np.random.shuffle`` and truncated to ``generated_size`` entries.

    Args:
        generated_size: Number of images to return (an integer).
        n_clusters: Number of clusters, i.e. of label files / filename groups.
        generated_dir: Directory holding the ``gan_<cluster>_*.jpg`` files.
            Required despite the ``None`` default.
        labels_dir: Directory holding the ``gan<i>.list`` files. Defaults to
            the location that used to be hard-coded.
        num_classes: Length of each label vector. ``None`` means "use the
            module-level ``n_classes``".

    Returns:
        A tuple ``(images, img_labels, flags)`` of equal-length lists:
        ``images`` holds the file base names prefixed with ``'gen_0000_'``,
        ``img_labels`` the label vector of each image's cluster, and
        ``flags`` the constant ``1`` for every image.

    Raises:
        AssertionError: If ``generated_dir`` is ``None`` or fewer than
            ``generated_size`` distinct images were found.
    """
    assert generated_dir is not None, 'generated_dir must be provided'
    if num_classes is None:
        # Resolved lazily so the module constant is only needed when the
        # caller does not supply a value.
        num_classes = n_classes

    # Build one normalised label vector per cluster.
    labels = []
    for i in range(n_clusters):
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        tmp_labels = np.zeros(shape=num_classes, dtype=np.float64)
        with open(os.path.join(labels_dir, 'gan%s.list' % i), 'r') as f:
            for line in f:
                lbl = line.strip()
                if not lbl:
                    # Tolerate blank / trailing empty lines.
                    continue
                tmp_labels[int(lbl)] = 1.0
        tmp_labels = tmp_labels / np.sum(tmp_labels)
        labels.append(tmp_labels)
    labels = np.array(labels)

    # Take slightly more than an even share per cluster so that the total is
    # guaranteed to reach generated_size when every cluster has enough files.
    n_gan = int(generated_size // n_clusters) + 1
    candidates = []
    for i in range(n_clusters):
        pattern = os.path.join(generated_dir, 'gan_%s*.jpg' % i)
        # glob order is filesystem-dependent; sorting makes the truncation
        # (and therefore the result under a fixed NumPy seed) reproducible.
        gan_list = sorted(glob.glob(pattern))
        candidates.extend(gan_list[:n_gan])

    # An explicit string dtype keeps an empty candidate list from becoming a
    # float array.
    data_list = np.unique(np.array(candidates, dtype=str))
    np.random.shuffle(data_list)
    assert data_list.shape[0] >= generated_size, (
        'found %d generated images, need %d'
        % (data_list.shape[0], generated_size))
    data_list = data_list[:generated_size]

    img_labels = []
    images = []
    flags = []
    for filename in data_list:
        img_name = os.path.basename(filename)
        # File names look like gan_<cluster>_...; the second field selects
        # the cluster's label vector.
        lbl = int(img_name.split('_')[1])
        img_labels.append(labels[lbl])
        images.append('gen_0000' + '_' + img_name)
        flags.append(1)

    assert len(images) == generated_size
    assert len(images) == len(img_labels) == len(flags)
    return images, img_labels, flags
class AverageMeter(object):
    """Computes and stores the average and current value.

    Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes are plain public fields:
      * val   -- the value passed to the most recent update() call
      * sum   -- running total of val * n over all updates
      * count -- running total of n over all updates
      * avg   -- sum / count, or 0 while count is 0
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: The latest value (typically a per-batch mean).
            n: Weight of ``val``, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) right after reset) would
        # divide by zero; keep the average at its reset value instead.
        self.avg = self.sum / self.count if self.count else 0
def read_json(fpath):
    """Parse the JSON document stored at ``fpath`` and return the result."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
def mkdir_if_missing(directory):
    """Create ``directory`` (and any missing parents) unless it already exists.

    An empty path is treated as "nothing to create" and returns silently,
    because os.makedirs('') raises. A path that already exists is left
    untouched.

    Raises:
        OSError: If creation fails for a reason other than the directory
            having been created concurrently.
    """
    if not directory or os.path.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError:
        # Another process may have created it between the check above and
        # makedirs; only propagate genuine failures.
        if not os.path.isdir(directory):
            raise
def write_json(obj, fpath):
    """Serialise ``obj`` as indented JSON to ``fpath``.

    The parent directory of ``fpath`` is created first when the path has a
    directory component. Output uses a 4-space indent and ``', '``-free
    separators (``','`` and ``': '``).

    Args:
        obj: Any JSON-serialisable object.
        fpath: Destination file path; an existing file is overwritten.
    """
    parent = os.path.dirname(fpath)
    # dirname is '' for a bare filename such as 'out.json'; there is no
    # directory to create in that case and os.makedirs('') would raise.
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))