-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
150 lines (119 loc) · 5.09 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
train BioFaceNet using part of CelebA dataset
"""
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from torch.optim import SGD
from datasets.celebA import CelebADataLoader
from BioFaceNet import BioFaceNet
from loss import loss
# argument parsing
import argparse
# CLI options: training length/optimisation (--epochs, --lr), debugging and
# visualisation switches (--show, --test_forward, --viz), and file locations
# (--save_dir for checkpoints, --data_dir for the hdf5 shards).
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train')
parser.add_argument('--show', action='store_true', help='if enabled, plot 5 sample images')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
parser.add_argument('--test_forward', action='store_true', help='if enabled, test forward pass by feeding 1 image')
parser.add_argument('--viz', action='store_true', help='if enabled, images of target/model output will be plotted every batch')
parser.add_argument('--save_dir', type=str, default="checkpoints/", help='directory for saving trained model')
parser.add_argument('--data_dir', type=str, default="data/", help='directory for training datasets')
# parsed at import time so both train() and the __main__ block can read it
args = parser.parse_args()
def train(args):
    """Train BioFaceNet on a CelebA subset, checkpointing after every epoch.

    Args:
        args: parsed CLI namespace providing epochs, lr, viz, save_dir
            and data_dir.
    """
    # Checkpoint directory must exist before the first torch.save().
    os.makedirs(args.save_dir, exist_ok=True)

    # Prefer the first CUDA device when one is available.
    # NOTE(review): batch tensors are never moved to this device here —
    # presumably the model/loader handle placement internally; confirm.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # hdf5 shards for inputs and linear-RGB targets; extra shards can be
    # uncommented/appended to either list.
    inmc_files = [
        'zx_7_d10_inmc_celebA_20.hdf5',
        # 'zx_7_d10_inmc_celebA_05.hdf5',
    ]
    lrgb_files = [
        'zx_7_d3_lrgb_celebA_20.hdf5',
        # 'zx_7_d3_lrgb_celebA_05.hdf5',
    ]
    # Prefix each shard name with the data directory.
    inmc_files = [os.path.join(args.data_dir, name) for name in inmc_files]
    lrgb_files = [os.path.join(args.data_dir, name) for name in lrgb_files]

    # Training dataloader over the listed shards.
    train_loader = CelebADataLoader(inmc_files, lrgb_files).loader

    # Network and plain-SGD optimizer.
    model = BioFaceNet(device=device)
    optimizer = SGD(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        with tqdm(train_loader, unit="batch") as progress:
            # Each batch yields image, normal map, mask, ground-truth
            # shading and spherical-harmonics lighting parameters
            # (normal and SH params are unused by this loop).
            for image, normal, mask, actual_shading, sh_params in progress:
                # Forward pass: per-pixel maps plus lighting parameters.
                fmel, fblood, pred_shading, pred_specular, b, lighting_params = model(image)
                # Physical-model-based decoding into an appearance image.
                appearance, pred_shading, pred_specular, b = model.decode(
                    fmel, fblood, pred_shading, pred_specular, b, lighting_params)

                # Optional per-batch visualisation of targets vs outputs.
                if args.viz:
                    model.visualize_training_progress(
                        image, actual_shading, mask, appearance,
                        pred_shading, pred_specular, fmel, fblood)

                # Bundle predictions and targets for the loss function.
                predicts = {
                    'appearance': appearance,
                    'b': b,
                    'specular': pred_specular,
                    'shading': pred_shading
                }
                targets = {
                    'appearance': image,
                    'shading': actual_shading,
                    'mask': mask
                }
                batch_loss = loss(predicts, targets)

                # Standard step: clear grads, backprop, update.
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # Progress-bar status: 1-based epoch counter and scalar loss.
                progress.set_postfix(
                    epoch=f"{epoch + 1}/{args.epochs}",
                    loss=batch_loss.cpu().detach().numpy())

        # Persist one checkpoint per epoch.
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),  # use model.load_state_dict(torch.load(XX)) when resume training
        }
        save_path = os.path.join(args.save_dir, f"model_checkpoint_{epoch}.pt")
        torch.save(state, save_path)
if __name__ == '__main__':
    # Run the full training loop with the parsed CLI arguments.
    train(args)
    if args.show:
        # Plot sample images when --show is given.
        # NOTE(review): constructed without the hdf5 file lists that
        # train() passes — confirm CelebADataLoader has default paths,
        # otherwise this call raises.
        CelebADataLoader().show_sample()
    # Test forward pass (--test_forward): push a single image through the
    # network and print the shapes of everything it predicts.
    if args.test_forward:
        # data path lists
        inmc_list = [
            'data/zx_7_d10_inmc_celebA_20.hdf5',
            # 'data/zx_7_d10_inmc_celebA_05.hdf5',
        ]
        lrgb_list = [
            'data/zx_7_d3_lrgb_celebA_20.hdf5',
            # 'data/zx_7_d3_lrgb_celebA_05.hdf5',
        ]
        # training dataloader
        train_loader = CelebADataLoader(inmc_list, lrgb_list).loader
        # network
        # NOTE(review): no device argument here, unlike train() —
        # confirm BioFaceNet's device parameter has a default.
        model = BioFaceNet()
        for batch in train_loader:
            image, normal, mask, actual_shading, spherical_harmonics_param = batch
            # Keep only the first sample of the batch, re-adding a
            # singleton batch dimension via [None, ...].
            image, normal, mask, actual_shading, spherical_harmonics_param = image[0][None,...], normal[0][None,...], mask[0][None,...], actual_shading[0][None,...], spherical_harmonics_param[0][None,...]
            output = model(image)
            fmel, fblood, shading, specular, b, lighting_params = output
            # Print the shape of each predicted map/parameter tensor.
            print(fmel.shape, fblood.shape, shading.shape, specular.shape, b.shape, lighting_params.shape)
            # plt.imshow(fmel[0][0].detach().numpy())
            # plt.show()
            print(lighting_params)
            # One image is enough for a forward-pass smoke test.
            break