Parameterize alpha and beta and fix small arch bugs
rien333 committed Jun 4, 2018
1 parent 6b00ccb commit 76210d2
Showing 5 changed files with 192 additions and 28 deletions.
151 changes: 151 additions & 0 deletions HybridDataset.py
@@ -0,0 +1,151 @@
import numpy as np
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import SOSDataset
import SynDataset
import cv2
import random

class HybridDataset(Dataset):

# maybe have two different transforms for the datasets
# Consider the dynamic between test and train (syn examples are totally random and pseudo-generated)
# as syn examples are pseudo-generated, consider importing way more
# so yeah, definitely use all files when this works!

def __init__(self, epochs, transform=None, grow_f=0.38, train=True,):
"""
grow_f is a factor in [0, 1] by which we grow the dataset size with synthetic examples
"""

self.train = train
self.classes = 5
# MAKE EXTENDED 🐝🐝🐝
self.sos = SOSDataset.SOSDataset(train=train, extended=True, transform=transform)
self.sos_sort = self.sos.sorted_classes()
self.sos_n = [len(s) for s in self.sos_sort]
self.r_samples = len(self.sos)
s_samples = round(self.r_samples * grow_f)
self.nsamples = self.r_samples + s_samples
# number of examples per class for perfect balance
self.class_n = self.nsamples / self.classes
# generate only training examples
self.syn = SynDataset.SynDataset(train=True, transform=transform, split=1)
self.syn_sort = self.syn.sorted_classes()
self.idx = 0
self.ridx = 0
# wait, this is dependent on the number of samples, right?
# although you might want to control these separately
self.t_incr = 1/(epochs+1)
self.t = 0
# curve bending, slow fade in
self.u1 = -0.03 # bezier steepness in the beginning (flat 0 at start if negative)
self.u2 = 0.03 # bezier steepness towards the end
# idk, but the imbalance ratio might not work because #c1 > class_n (ratio of 1.6)
imbalance_ratio = np.clip(np.array([n / self.class_n for n in self.sos_n]), 0, 0.96)
sum_w = np.sum(imbalance_ratio)
# Normalize
imbalance_weights = (imbalance_ratio / sum_w)
# class_w are ordered from most to least samples with respect to the size of their weights, so reverse
self.imbalance_weights = imbalance_weights
ws = (self.update_class_weights(1))**(10)

self.sort_map = [list(ws).index(w) for w in sorted(ws)]
print(self.sort_map)
self.class_w_cum = np.cumsum(np.ones(5, dtype=np.float)) # weigh equally if only synthetic data
# self.class_w_cum = np.cumsum(sorted(ws/np.sum(ws))) # weigh equally if only synthetic data
# self.class_w_cum = np.cumsum(sorted(ws)) # weigh equally if only synthetic data
self.syn_ratio = self.__bezier(self.t, self.u1, self.u2)
# self.syn_ratio = 1 # self.t = 1
from collections import Counter
self.syn_counter = Counter()

def __bezier(self, t, u1, u2):
# instead of nsamples use len(self)? or update self.nsamples in the len function occasionally?
# u0 = 0.0 # fixed
# u3 = 1.0 # fixed
# see bezier.py to graph stuff and some extra settings?
return max(0, min(1, (3*u1*((1-t)**2))*t+3*u2*(1-t)*(t**2)+t**3))

def update_class_weights(self, syn_ratio):
# Bring the weights from being equal increasingly close to reflecting the imbalance
# instead of one try the cumsum (or normalize the weights)
return 1 - (self.imbalance_weights * syn_ratio)

# samples a class according to the imbalance in the dataset with weighted distributions
def balanced_sample(self):
interval = np.random.uniform() * self.class_w_cum[-1]
i = np.searchsorted(self.class_w_cum, interval)
# i is the index of the weight corresponding to the class that needs to be generated
return self.sort_map[i]

def synthetic_sample(self):
# if synthetic, sample so as to compensate for the class imbalance
c = self.balanced_sample()
# Sample a random image from that class
s_idx = random.choice(self.syn_sort[c])
return self.syn[s_idx]

# Maybe make a choice of whether you want a balanced sample from the real data as well
def __getitem__(self, index):
r = np.random.uniform()
# not too sure about the grow_f here
if (r < self.syn_ratio) and self.ridx < self.r_samples:
# sample real
# consider calling balanced sample here as well
s = self.sos[self.ridx]
self.ridx += 1 # a ternary check is less expensive than a modulo, although wrapping might be desired
else:
s = self.synthetic_sample()
self.syn_counter[s[1]] += 1


self.idx += 1
if self.idx >= self.nsamples:
self.t += self.t_incr
print("t", self.t)
n_syn_ratio = self.__bezier(self.t, self.u1, self.u2)
print("new ratio: ", n_syn_ratio)
n_ws = (self.update_class_weights(n_syn_ratio))**(self.syn_ratio*10)
# normalize
n_ws /= np.sum(n_ws)
self.class_w_cum = np.cumsum(sorted(n_ws))
print("new weights", n_ws)
print("sorted cum", self.class_w_cum)
self.syn_ratio = n_syn_ratio
# self.ridx = 0 # Uncomment!
self.idx = 0


return s

# We will probably update this dynamically, but keep it in sync with the batch size! (nicely divisible?)
def __len__(self):
return self.nsamples

if __name__ == "__main__":
from collections import Counter
t = [SOSDataset.Rescale((232, 232)), SOSDataset.RandomCrop((SOSDataset.DATA_W, SOSDataset.DATA_H))]
epochs=20
hd = HybridDataset(epochs=epochs, transform=t, train=True, grow_f=0.50)
samples = len(hd)
for epoch in range(epochs+2):
classes = Counter()
for s in range(samples):
try:
classes[int(hd[s][1])] += 1
except:
print("Incorrect class label I think?")
print("idx", s)
print("idx", hd[s], "len", len(hd[s]))

print("Real examples:", hd.ridx)
hd.ridx = 0 # remove!
print("Syn", sorted(hd.syn_counter.items(), key=lambda pair: pair[0], reverse=False))
hd.syn_counter = Counter()
print("All", sorted(classes.items(), key=lambda pair: pair[0], reverse=False))
print("------------------------------")
# exit(0)
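A note on the two sampling mechanisms in HybridDataset.py. The real-versus-synthetic mix is driven by __bezier, a cubic Bézier ease-in with fixed endpoints 0 and 1, control values u1/u2, and the result clamped to [0, 1]. Below is a minimal standalone sketch of how that schedule evolves per epoch, using the same formula and defaults as the class (illustrative only, not part of the commit):

# Standalone sketch of the fade-in schedule (same formula and defaults as HybridDataset).
def bezier(t, u1=-0.03, u2=0.03):
    # cubic Bezier with fixed endpoints u0=0, u3=1, clamped to [0, 1]
    return max(0.0, min(1.0, 3*u1*((1 - t)**2)*t + 3*u2*(1 - t)*(t**2) + t**3))

epochs = 20
t_incr = 1 / (epochs + 1)
for epoch in range(epochs + 1):
    print("epoch", epoch, "ratio", round(bezier(epoch * t_incr), 3))

balanced_sample draws a class index by inverse-transform sampling over the cumulative class weights: a uniform draw is scaled by the last cumulative value and located in the cumulative array with np.searchsorted. A condensed sketch of the same idea with made-up weights:

import numpy as np

# Hypothetical per-class weights (rarer classes get larger weights).
weights = np.array([0.10, 0.15, 0.20, 0.25, 0.30])
cum = np.cumsum(weights)

def balanced_sample(cum_weights):
    # Inverse-transform sampling: a uniform draw scaled to the total weight,
    # then placed into the cumulative array.
    u = np.random.uniform() * cum_weights[-1]
    return int(np.searchsorted(cum_weights, u))

counts = np.bincount([balanced_sample(cum) for _ in range(10000)], minlength=len(weights))
print(counts / counts.sum())  # roughly proportional to weights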

7 changes: 1 addition & 6 deletions HybridEqualDataset.py
@@ -15,13 +15,9 @@ class HybridEqualDataset(Dataset):
# Consider the dynamic between test and train (syn examples are totally random and pseudo-generated)
# as syn examples are pseudo-generated, consider importing way more

<<<<<<< HEAD

def __init__(self, epochs, transform=None, grow_f=0.38, t=0.0, datadir="../Datasets/", sorted_loc="/tmp",
syn_samples=None, real_samples=None, train=True,):
=======
def __init__(self, epochs, transform=None, grow_f=0.38, t=0.0, datadir="../Datasets/", sorted_loc="/tmp/",
syn_samples=[], real_samples=[], train=True,):
>>>>>>> f9b59aff8d3c5196ea35b34335fdb9431d531c87
"""
grow_f is a factor in [0, 1] by which we grow the dataset size with synthetic examples
"""
@@ -120,7 +116,6 @@ def __len__(self):
# syn_samples = [4700, 5400, 8023, 8200, 8700]
# real_samples = [1101, 1100, 1604, 1058, 853]


hd = HybridEqualDataset(epochs=epochs, transform=t, train=True, t=0.775, grow_f=6.2952)

samples = len(hd)
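One of the two conflicting signatures above uses syn_samples=[] / real_samples=[] as defaults, the other uses None. Whichever side the commit kept (the rendering above does not show the resolution), None is the safer choice, because a mutable default is created once and shared across calls. A short illustration of the pitfall (generic Python, not taken from the repository):

def bad(samples=[]):
    # The same list object is reused on every call without an argument.
    samples.append(1)
    return samples

def good(samples=None):
    # A fresh list per call.
    samples = [] if samples is None else samples
    samples.append(1)
    return samples

print(bad())   # [1]
print(bad())   # [1, 1]  <- state leaks between calls
print(good())  # [1]
print(good())  # [1]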
3 changes: 2 additions & 1 deletion SOSDataset.py
@@ -126,7 +126,8 @@ def __init__(self, train=True, transform=None, datadir="../Datasets/", sorted_l
self.transform_name = ''.join([t.__class__.__name__ for t in transform])
else:
self.transform = None
self.sorted_loc = sorted_loc + "sorted_classes_sos_" + str(self.train)+".pickle"
self.sorted_loc = sorted_loc + "/sorted_classes_sos_" + str(self.train)+".pickle"

# Read in the .mat file
if extended:
import scipy.io as sio
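The SOSDataset.py change above fixes a missing path separator by inserting "/" into the string concatenation. An alternative sketch (not what the commit does) is os.path.join, which gives the same result whether sorted_loc is passed with or without a trailing slash:

import os

# Hypothetical illustration: os.path.join treats "/tmp" and "/tmp/" identically.
for sorted_loc in ("/tmp", "/tmp/"):
    pickle_path = os.path.join(sorted_loc, "sorted_classes_sos_True.pickle")
    print(pickle_path)  # /tmp/sorted_classes_sos_True.pickle in both cases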
10 changes: 7 additions & 3 deletions SynDataset.py
@@ -14,7 +14,7 @@

class SynDataset(Dataset):

def __init__(self, train=True, transform=None, datadir="../Datasets/", sorted_loc="/tmp/", n=None, split=0.8):
def __init__(self, train=True, transform=None, datadir="../Datasets/", sorted_loc="/tmp", n=None, split=0.8):
self.datadir = datadir
self.train = train
self.transform = transforms.Compose(transform)
@@ -28,7 +28,7 @@ def __init__(self, train=True, transform=None, datadir="../Datasets/", sorted_lo
nfiles = len(self.files)
self.train_range = int(split * nfiles) # convert to idx
self.nsamples = self.train_range if train else nfiles - self.train_range
self.sorted_loc = sorted_loc + "sorted_classes_syn.pickle"
self.sorted_loc = sorted_loc + "/sorted_classes_syn.pickle"
if os.path.isfile(self.sorted_loc):
with open (self.sorted_loc, 'rb') as f:
sorted_classes = pickle.load(f)
@@ -41,7 +41,11 @@ def __len__(self):
def __getitem__(self, index):
start = 0 if self.train else self.train_range
im_name = self.datadir + self.files[start+index]
im = cv2.cvtColor(cv2.imread(im_name), cv2.COLOR_BGR2RGB)
try:
im = cv2.cvtColor(cv2.imread(im_name), cv2.COLOR_BGR2RGB)
except:
print(im_name)
exit(0)
label = int(im_name[-5])
return self.transform((im, label))

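On the SynDataset.py change: cv2.imread does not raise when a file is missing or unreadable, it returns None, so the failure only surfaces as a cv2.error inside cv2.cvtColor, which is what the new try/except catches before printing the offending path. A sketch of a more explicit check (an alternative, not part of the commit):

import cv2

def load_rgb(im_name):
    # cv2.imread returns None on a missing/corrupt file instead of raising,
    # so check before converting the colour space.
    im = cv2.imread(im_name)
    if im is None:
        raise FileNotFoundError("Could not read image: " + im_name)
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)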
49 changes: 31 additions & 18 deletions conv_vae_pytorch.py
@@ -30,6 +30,10 @@
help='enables CUDA training')
parser.add_argument('--z-dims', type=int, default=20, metavar='N',
help='dimensionality of the latent z variable')
parser.add_argument('--alpha', type=float, default=1.0, metavar='N',
help='Weight of KLD loss')
parser.add_argument('--beta', type=float, default=0.25, metavar='N',
help='Weight of content loss')
parser.add_argument('--dfc', action='store_true', default=False, help="Train with deep feature consistency loss")
parser.add_argument('--full-con-size', type=int, default=400, metavar='N',
help='size of the fully connected layer')
@@ -39,8 +43,7 @@
help='epoch to start at (only affects logging)')
parser.add_argument('--test-interval', type=int, default=10, metavar='N',
help='when to run a test epoch')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')

@@ -233,7 +236,8 @@ def __init__(self):
self.fc4 = nn.Sequential(
# nn.Linear(args.full_con_size, int(np.prod(self.deconv_shape))),
nn.Linear(args.z_dims, int(np.prod(self.deconv_shape))),
nn.LeakyReLU(0.01), # Some people use normal relu here
# nn.LeakyReLU(0.01), # Some people use normal relu here
nn.ReLU(), # Some people use normal relu here
nn.BatchNorm1d(int(np.prod(self.deconv_shape))) # unneeded?
)

@@ -399,7 +403,8 @@ def __init__(self):
self.norm_layer = ImageNet_Norm_Layer_2() # norm done in the net so as not to screw up the input

# ngpu = torch.cuda.device_count()
self.ngpu = ngpu if not lisa_check else 1 # too much mem # assign to ngpu
# self.ngpu = ngpu if not lisa_check else 1 # too much mem # assign to ngpu
self.ngpu = ngpu
if self.ngpu > 1:
# Functional equivalent of below (unsure if this is problematic; maybe it's fine)
self.gpu_func = lambda module, output: nn.parallel.data_parallel(module, output, range(self.ngpu))
@@ -413,12 +418,18 @@ def __init__(self):
'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']
# Add one?
self.content_layers = ['relu1_1', 'relu2_1', 'relu3_1',]
content_layers = ['relu1_1', 'relu2_1', 'relu3_1',]
self.content_layers = list(content_layers)

self.features = nn.Sequential()
for i, module in enumerate(features):
name = self.layer_names[i]
self.features.add_module(name, module)
self.features.add_module(name, module)
if name in content_layers:
content_layers.remove(name)
if not content_layers:
# Stop adding stuff
break

def forward(self, x):
batch_size = x.size(0) # needed because sometimes the batch size is not exactly args.batch_size
@@ -430,9 +441,6 @@ def forward(self, x):
output = self.gpu_func(module, output)
if name in self.content_layers:
all_outputs.append(output.view(batch_size, -1))
# visited += 1
# if visited >= visits:
# break
return all_outputs

# this has to be a trainable module for some reason
@@ -446,9 +454,11 @@ def __init__(self, alpha=1, beta=0.5):
self.criterion = nn.MSELoss(size_average=False)

def forward(self, output, target, mean, logvar):
# people use sum here instead of mean (DFC authors versus the standard PyTorch implementation) 🌸
# sum seems to have weird graphical glitches?
kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
# Note detach
loss_list = [self.criterion(output[layer], target[layer].detach()) for layer in range(len(output))]
# Note detach for target
loss_list = [self.criterion(output[layer], target[layer]) for layer in range(len(output))]
content = sum(loss_list)
return self.alpha * kld + self.beta * content

@@ -478,7 +488,7 @@ def weights_init(m):
for param in descriptor.parameters():
param.requires_grad = False

content_loss = Content_Loss(alpha=1.0, beta=0.5)
content_loss = Content_Loss(alpha=args.alpha, beta=args.beta)
content_loss.to(device)

if ngpu > 1:
@@ -529,11 +539,12 @@ def loss_function_dfc_split(recon_x, x, mu, logvar):
targets = descriptor(x) # vgg
recon_features = descriptor(recon_x)
# BCE = F.binary_cross_entropy(recon_x, x)
# Note the mean versus sum thing also mentioned above 🌸
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss_list = [F.mse_loss(recon_features[layer], targets[layer].detach(), size_average=False) for layer in range(len(recon_features))]
# note the detach
loss_list = [F.mse_loss(recon_features[layer], targets[layer], size_average=False) for layer in range(len(targets))]
content = sum(loss_list)

return 1.0*kld, 0.5*content,
return args.alpha*kld, args.beta*content,
# return content_loss(recon_features, targets, mu, logvar)

# Check for dfc loss
@@ -584,7 +595,7 @@ def test(epoch, loader):

# ~50 sets of random ZDIMS-float vectors to images
# Weird hack bc this is drawn from ~ N(0, 1), and our distribution looks different
sample = torch.randn(49, args.z_dims).to(device) * 1.1
sample = torch.randn(49, args.z_dims).to(device) * 4.2

if ngpu > 1:
sample = model.module.decode(sample)
@@ -637,22 +648,24 @@ def train_routine(epochs, train_loader, test_loader, optimizer, scheduler, reset
if __name__ == "__main__":

grow_f=6.2952 # Lisa size
# grow_f=6.2952/4 # Lisa size
# grow_f=3.5032
hybrid_train_loader = torch.utils.data.DataLoader(
HybridEqualDataset.HybridEqualDataset(epochs=args.epochs-6, train=True, transform=data_transform,
t=0.775,grow_f=grow_f, datadir=DATA_DIR, sorted_loc=SORT_DIR),
t=0.605,grow_f=grow_f, datadir=DATA_DIR, sorted_loc=SORT_DIR),
batch_size=args.batch_size, shuffle=True, **kwargs)

hybrid_test_loader = torch.utils.data.DataLoader(
HybridEqualDataset.HybridEqualDataset(epochs=args.epochs-6, train=False, transform=data_transform,
t=0.775,grow_f=2.0, datadir=DATA_DIR, sorted_loc=SORT_DIR),
t=0.605,grow_f=2.0, datadir=DATA_DIR, sorted_loc=SORT_DIR),
batch_size=args.batch_size, shuffle=True, **kwargs)

# # optimizer = optim.Adam(model.parameters(), lr=1e-3) # = 0.001
optimizer = optim.Adam(model.parameters(), lr=0.00135)
# Decay lr if nothing happens after 4 epochs (try 3?)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.23, patience=4, cooldown=1,
verbose=True)

train_routine(args.epochs, train_loader=hybrid_train_loader, test_loader=hybrid_test_loader,
optimizer=optimizer, scheduler=scheduler, reset=102)

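The headline change in conv_vae_pytorch.py is that the KLD and content terms of the DFC loss are now weighted by the new --alpha and --beta flags instead of the hard-coded 1.0 and 0.5. A condensed, standalone sketch of that weighting, using the modern reduction argument in place of the deprecated size_average (the feature lists stand in for the VGG descriptor outputs above):

import torch
import torch.nn.functional as F

def dfc_loss(recon_features, target_features, mu, logvar, alpha=1.0, beta=0.25):
    # KL divergence between q(z|x) = N(mu, sigma^2) and the unit Gaussian prior.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Summed per-layer MSE between reconstructed and target feature maps.
    content = sum(F.mse_loss(r, t, reduction="sum")
                  for r, t in zip(recon_features, target_features))
    return alpha * kld + beta * content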
