forked from HypoX64/DeepMosaics
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloadmodel.py
executable file
·73 lines (64 loc) · 2.63 KB
/
loadmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from . import model_util
from .pix2pix_model import define_G as pix2pix_G
from .pix2pixHD_model import define_G as pix2pixHD_G
# from .video_model import MosaicNet
# from .videoHD_model import MosaicNet as MosaicNet_HD
from .BiSeNet_model import BiSeNet
from .BVDNet import define_G as video_G
def show_paramsnumber(net, netname='net'):
    """Print the number of parameters of ``net``, expressed in millions.

    Args:
        net: Object exposing ``parameters()``; every yielded item must
            provide ``numel()``.
        netname: Label placed at the start of the printed line.
    """
    total = 0
    for param in net.parameters():
        total += param.numel()
    # Report in millions, rounded to two decimals.
    millions = round(total / 1e6, 2)
    label = netname + ' parameters: '
    print(label + str(millions) + 'M')
def pix2pix(opt):
    """Build a pix2pix or pix2pixHD generator and load its weights.

    Args:
        opt: Options object. Reads ``opt.netG`` (the value ``'HD'`` selects
            the pix2pixHD ``'global'`` generator; any other value is passed
            to the pix2pix factory as the architecture name),
            ``opt.model_path`` (checkpoint file) and ``opt.gpu_id``.

    Returns:
        The generator in eval mode, as returned by ``model_util.todevice``.
    """
    if opt.netG == 'HD':
        netG = pix2pixHD_G(3, 3, 64, 'global', 4)
    else:
        netG = pix2pix_G(3, 3, 64, opt.netG, norm='batch', use_dropout=True,
                         init_type='normal', gpu_ids=[])
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU so a checkpoint saved from a CUDA device still
    # loads on a CPU-only host; device placement is done by todevice() below.
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    state_dict = torch.load(opt.model_path, map_location='cpu')
    netG.load_state_dict(state_dict)
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG
def style(opt):
    """Build the ``resnet_9blocks`` style generator and load its weights.

    Args:
        opt: Options object. Reads ``opt.edges`` (truthy selects the
            1-channel-input variant with dropout enabled; otherwise the
            3-channel-input variant without dropout), ``opt.model_path``
            and ``opt.gpu_id``.

    Returns:
        The generator in eval mode, as returned by ``model_util.todevice``.
    """
    input_nc, with_dropout = (1, True) if opt.edges else (3, False)
    netG = pix2pix_G(input_nc, 3, 64, 'resnet_9blocks', norm='instance',
                     use_dropout=with_dropout, init_type='normal', gpu_ids=[])

    # Loading procedure kept compatible with old pretrained models; see
    # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/models/base_model.py
    if isinstance(netG, torch.nn.DataParallel):
        # Load into the wrapped module rather than the DataParallel wrapper.
        netG = netG.module

    weights = torch.load(opt.model_path, map_location='cpu')
    if hasattr(weights, '_metadata'):
        del weights._metadata

    # Patch InstanceNorm checkpoints saved prior to PyTorch 0.4. The key list
    # is snapshotted up front because the patch helper may mutate the dict.
    for name in list(weights.keys()):
        model_util.patch_instance_norm_state_dict(weights, netG, name.split('.'))

    netG.load_state_dict(weights)
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG
def video(opt):
    """Build the video generator and load its weights.

    Args:
        opt: Options object. Reads ``opt.gpu_id`` (forwarded to the
            generator factory and to ``model_util.todevice``) and
            ``opt.model_path`` (checkpoint file).

    Returns:
        The generator in eval mode, as returned by ``model_util.todevice``.
    """
    netG = video_G(N=2, n_blocks=4, gpu_id=opt.gpu_id)
    show_paramsnumber(netG, 'netG')
    # Deserialize onto the CPU so a checkpoint saved from a CUDA device still
    # loads on a CPU-only host; load_state_dict copies the values into the
    # model's own parameters, and todevice() below handles placement.
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    state_dict = torch.load(opt.model_path, map_location='cpu')
    netG.load_state_dict(state_dict)
    netG = model_util.todevice(netG, opt.gpu_id)
    netG.eval()
    return netG
def bisenet(opt, type='roi'):
    """Build a BiSeNet segmentation network and load its weights.

    Args:
        opt: Options object. Reads ``opt.gpu_id`` and, depending on
            ``type``, ``opt.model_path`` or
            ``opt.mosaic_position_model_path``.
        type: ``'roi'`` loads weights from ``opt.model_path``; ``'mosaic'``
            loads them from ``opt.mosaic_position_model_path``. For any
            other value no checkpoint is loaded and the network keeps the
            weights it was constructed with.

    Returns:
        The network in eval mode, as returned by ``model_util.todevice``.
    """
    net = BiSeNet(num_classes=1, context_path='resnet18', train_flag=False)
    show_paramsnumber(net, 'segment')
    # Deserialize onto the CPU so a checkpoint saved from a CUDA device still
    # loads on a CPU-only host; device placement is done by todevice() below.
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    if type == 'roi':
        net.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
    elif type == 'mosaic':
        net.load_state_dict(
            torch.load(opt.mosaic_position_model_path, map_location='cpu'))
    net = model_util.todevice(net, opt.gpu_id)
    net.eval()
    return net