-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
71 lines (55 loc) · 2.27 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Make inference using trained model
"""
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import os
import re
from PIL import Image
from BioFaceNet import BioFaceNet
# argument parsing
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--image_path', type=str, default="utils/test_img.png", help='image filepath for model input')
# BUG FIX: help text was copy-pasted from --image_path; this argument is a checkpoint directory
parser.add_argument('--model_dir', type=str, default="checkpoints/", help='directory containing saved model checkpoints')
parser.add_argument('--epoch', type=int, default=-1, help='an int, specifies which epoch to use under checkpoints/ (-1 = latest)')
parser.add_argument('--output_path', type=str, default="predicted_output/", help='directory for saving output maps')
args = parser.parse_args()
def predict(args):
    """Run inference with a trained BioFaceNet model on a single image.

    Loads the image at ``args.image_path``, restores the checkpoint for
    ``args.epoch`` from ``args.model_dir`` (``-1`` selects the latest epoch
    found there), runs the model, decodes its outputs and visualizes the
    resulting maps.

    Args:
        args: argparse.Namespace with ``image_path``, ``model_dir``,
            ``epoch`` and ``output_path`` attributes.

    Raises:
        ValueError: if ``args.epoch`` is newer than any saved checkpoint.
    """
    # auto enable gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # load input image as PIL Image, resized to the model's 64x64 input
    image = Image.open(args.image_path).convert('RGB').resize((64, 64))
    # convert to [0...1] floats with a leading batch dimension
    image = torchvision.transforms.functional.to_tensor(image)[None, ...]
    # BUG FIX: Tensor.to() is NOT in-place — the original discarded the result,
    # leaving the input on CPU while the model may sit on GPU (device mismatch).
    image = image.to(device)
    print(image.shape)
    # init model in eval mode (disables dropout / batch-norm updates)
    model = BioFaceNet(device=device)
    model.eval()
    model.to(device)
    # get which saved model to use, default use last epoch
    model_list = os.listdir(args.model_dir)
    # checkpoint files look like *****_20.pt; pull the epoch number out of each name
    epoch_list = [int(re.findall(r'[0-9]+', name)[0]) for name in model_list]
    last_epoch = max(epoch_list)
    print(last_epoch)
    if args.epoch == -1:
        epoch_to_use = last_epoch
    else:
        # explicit error instead of assert: asserts are stripped under `python -O`
        if args.epoch > last_epoch:
            raise ValueError(
                "Requested epoch {} but latest checkpoint is epoch {}".format(args.epoch, last_epoch))
        epoch_to_use = args.epoch
    model_base_name = "model_checkpoint_" + str(epoch_to_use) + ".pt"
    model_path = os.path.join(args.model_dir, model_base_name)
    print("Using trained model: {}".format(model_path))
    # load weights (the checkpoint stores more than weights; pick 'state_dict').
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'])
    # infer without tracking gradients — no backward pass is needed, saves memory
    with torch.no_grad():
        fmel, fblood, shading, specular, b, lighting_params = model(image)
        # decode raw network outputs into displayable maps
        appearance, shading, specular, _ = model.decode(fmel, fblood, shading, specular, b, lighting_params)
    model.visualize_output(image, appearance, shading, specular, fmel, fblood)
# Script entry point: run prediction with the CLI args parsed at module level.
if __name__ == '__main__':
    predict(args)