Skip to content

Commit

Permalink
Gan code finished!
Browse files Browse the repository at this point in the history
  • Loading branch information
HypoX64 committed Apr 24, 2021
1 parent 796b59d commit 1749be9
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 342 deletions.
4 changes: 2 additions & 2 deletions cores/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ def cleanmosaic_video_fusion(opt,netG,netM):
mosaic_input[:,:,k*3:(k+1)*3] = impro.resize(img_pool[k][y-size:y+size,x-size:x+size], INPUT_SIZE)
mask_input = impro.resize(mask,np.min(img_origin.shape[:2]))[y-size:y+size,x-size:x+size]
mosaic_input[:,:,-1] = impro.resize(mask_input, INPUT_SIZE)
mosaic_input_tensor = data.im2tensor(mosaic_input,bgr2rgb=False,gpu_id=opt.gpu_id,use_transform = False,is0_1 = False)
mosaic_input_tensor = data.im2tensor(mosaic_input,bgr2rgb=False,gpu_id=opt.gpu_id)
unmosaic_pred = netG(mosaic_input_tensor)
img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False ,is0_1 = False)
img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = False)
img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather)
except Exception as e:
print('Warning:',e)
Expand Down
24 changes: 8 additions & 16 deletions models/BVDNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,14 @@ def forward(self, stream, previous):

def define_G(N=2, n_blocks=1, gpu_id='-1'):
netG = BVDNet(N = N, n_blocks=n_blocks)
if gpu_id != '-1' and len(gpu_id) == 1:
netG.cuda()
elif gpu_id != '-1' and len(gpu_id) > 1:
netG = nn.DataParallel(netG)
netG.cuda()
# netG.apply(model_util.init_weights)
netG = model_util.todevice(netG,gpu_id)
netG.apply(model_util.init_weights)
return netG

################################Discriminator################################
def define_D(input_nc=6, ndf=64, n_layers_D=3, use_sigmoid=False, num_D=4, gpu_id='-1'):
def define_D(input_nc=6, ndf=64, n_layers_D=1, use_sigmoid=False, num_D=3, gpu_id='-1'):
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, num_D)
if gpu_id != '-1' and len(gpu_id) == 1:
netD.cuda()
elif gpu_id != '-1' and len(gpu_id) > 1:
netD = nn.DataParallel(netD)
netD.cuda()
netD = model_util.todevice(netD,gpu_id)
netD.apply(model_util.init_weights)
return netD

Expand Down Expand Up @@ -191,16 +183,16 @@ def forward(self, dis_fake = None, dis_real = None):
if self.mode == 'D':
loss = 0
for i in range(len(dis_fake)):
loss += self.lossf(dis_fake[i][0],dis_real[i][0])
loss += self.lossf(dis_fake[i][-1],dis_real[i][-1])
elif self.mode =='G':
loss = 0
weight = 2**len(dis_fake)
for i in range(len(dis_fake)):
weight = weight/2
loss += weight*self.lossf(dis_fake[i][0])
loss += weight*self.lossf(dis_fake[i][-1])
return loss
else:
if self.mode == 'D':
return self.lossf(dis_fake[0],dis_real[0])
return self.lossf(dis_fake[-1],dis_real[-1])
elif self.mode =='G':
return self.lossf(dis_fake[0])
return self.lossf(dis_fake[-1])
6 changes: 3 additions & 3 deletions models/BiSeNet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from . import components
from . import model_util
import warnings
warnings.filterwarnings(action='ignore')

Expand Down Expand Up @@ -43,7 +43,7 @@ def forward(self, output, target):
class resnet18(torch.nn.Module):
def __init__(self, pretrained=True):
super().__init__()
self.features = components.resnet18(pretrained=pretrained)
self.features = model_util.resnet18(pretrained=pretrained)
self.conv1 = self.features.conv1
self.bn1 = self.features.bn1
self.relu = self.features.relu
Expand All @@ -70,7 +70,7 @@ def forward(self, input):
class resnet101(torch.nn.Module):
def __init__(self, pretrained=True):
super().__init__()
self.features = components.resnet101(pretrained=pretrained)
self.features = model_util.resnet101(pretrained=pretrained)
self.conv1 = self.features.conv1
self.bn1 = self.features.bn1
self.relu = self.features.relu
Expand Down
234 changes: 0 additions & 234 deletions models/components.py

This file was deleted.

30 changes: 7 additions & 23 deletions models/loadmodel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from . import model_util
from .pix2pix_model import define_G
from .pix2pixHD_model import define_G as define_G_HD
from .unet_model import UNet
from .video_model import MosaicNet
from .videoHD_model import MosaicNet as MosaicNet_HD
from .BiSeNet_model import BiSeNet
Expand All @@ -11,19 +11,6 @@ def show_paramsnumber(net,netname='net'):
parameters = round(parameters/1e6,2)
print(netname+' parameters: '+str(parameters)+'M')

def __patch_instance_norm_state_dict(state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

def pix2pix(opt):
# print(opt.model_path,opt.netG)
Expand All @@ -33,9 +20,8 @@ def pix2pix(opt):
netG = define_G(3, 3, 64, opt.netG, norm='batch',use_dropout=True, init_type='normal', gpu_ids=[])
show_paramsnumber(netG,'netG')
netG.load_state_dict(torch.load(opt.model_path))
netG = model_util.todevice(netG,opt.gpu_id)
netG.eval()
if opt.gpu_id != -1:
netG.cuda()
return netG


Expand All @@ -57,11 +43,11 @@ def style(opt):

# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
__patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
model_util.patch_instance_norm_state_dict(state_dict, netG, key.split('.'))
netG.load_state_dict(state_dict)

if opt.gpu_id != -1:
netG.cuda()
netG = model_util.todevice(netG,opt.gpu_id)
netG.eval()
return netG

def video(opt):
Expand All @@ -71,9 +57,8 @@ def video(opt):
netG = MosaicNet(3*25+1, 3,norm = 'batch')
show_paramsnumber(netG,'netG')
netG.load_state_dict(torch.load(opt.model_path))
netG = model_util.todevice(netG,opt.gpu_id)
netG.eval()
if opt.gpu_id != -1:
netG.cuda()
return netG

def bisenet(opt,type='roi'):
Expand All @@ -86,7 +71,6 @@ def bisenet(opt,type='roi'):
net.load_state_dict(torch.load(opt.model_path))
elif type == 'mosaic':
net.load_state_dict(torch.load(opt.mosaic_position_model_path))
net = model_util.todevice(net,opt.gpu_id)
net.eval()
if opt.gpu_id != -1:
net.cuda()
return net
Loading

0 comments on commit 1749be9

Please sign in to comment.