Commit: refine models

quqixun committed Sep 29, 2022
1 parent 16bcf58 commit b3b9735
Showing 11 changed files with 69 additions and 431 deletions.
evaluate.py (4 changes: 2 additions & 2 deletions)
@@ -26,9 +26,9 @@ def main(args):

     # initializes trainer
     if args.evaluator == 'basic':
-        evaluator = BCIBasicEvaluator(configs, model_path, apply_tta)
+        evaluator = BCIEvaluatorBasic(configs, model_path, apply_tta)
     elif args.evaluator == 'cahr':
-        evaluator = BCICAHREvaluator(configs, model_path, apply_tta)
+        evaluator = BCIEvaluatorCAHR(configs, model_path, apply_tta)

     # generates predictions
     evaluator.forward(args.data_dir, output_dir)
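Since this is a pure rename, external call sites have to switch to the new class names as well. Below is a minimal, hypothetical usage sketch (not part of this commit) that mirrors the dispatch in evaluate.py; the libs.evaluate module paths are taken from the file names in this diff, and the build_evaluator helper is an assumption.

# Hypothetical sketch: constructing the renamed evaluators outside evaluate.py.
from libs.evaluate.evaluator_basic import BCIEvaluatorBasic
from libs.evaluate.evaluator_cahr import BCIEvaluatorCAHR


def build_evaluator(name, configs, model_path, apply_tta=False):
    # mirrors the if/elif dispatch in main() after this change
    if name == 'basic':
        return BCIEvaluatorBasic(configs, model_path, apply_tta)
    elif name == 'cahr':
        return BCIEvaluatorCAHR(configs, model_path, apply_tta)
    raise ValueError(f'unknown evaluator: {name}')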
libs/evaluate/evaluator_basic.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
 from ..utils import normalize_image, unnormalize_image, tta, untta


-class BCIBasicEvaluator(object):
+class BCIEvaluatorBasic(object):

     def __init__(self, configs, model_path, apply_tta=False):

libs/evaluate/evaluator_cahr.py (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
 from ..utils import normalize_image, unnormalize_image, tta, untta


-class BCICAHREvaluator(object):
+class BCIEvaluatorCAHR(object):

     def __init__(self, configs, model_path, apply_tta=False):

libs/models/C.py (49 changes: 16 additions & 33 deletions)
@@ -9,20 +9,20 @@

 def define_C(configs):

-    if configs.name == 'simple':
-        net = SimpleComparator(**configs.params)
-    elif configs.name == 'mobilenetv2':
-        net = MobileNetV2Comparator(**configs.params)
-    elif configs.name == 'shufflenetv2':
-        net = ShuffleNetV2Comparator(**configs.params)
+    if configs.name == 'comparator_basic':
+        net = ComparatorBasic(**configs.params)
+    elif configs.name == 'comparator_mobilenetv2':
+        net = ComparatorMobileNetV2(**configs.params)
+    elif configs.name == 'comparator_shufflenetv2':
+        net = ComparatorShuffleNetV2(**configs.params)
     else:
         raise NotImplementedError(f'unknown C model name {configs.name}')

     init_weights(net, **configs.init)
     return net


-class SimpleComparator(nn.Module):
+class ComparatorBasic(nn.Module):

     def __init__(self,
                  full_size=1024,
@@ -31,9 +31,10 @@ def __init__(self,
                  max_channels=256,
                  levels=4,
                  norm_type='batch',
-                 dropout=0.2
+                 dropout=0.2,
+                 attention=False
                  ):
-        super(SimpleComparator, self).__init__()
+        super(ComparatorBasic, self).__init__()

         assert norm_type in ['batch', 'instance', 'none']
         norm_layer = get_norm_layer(norm_type=norm_type)
@@ -44,7 +45,7 @@ def __init__(self,
                 in_dims=input_channels, out_dims=init_channels,
                 conv_type='conv2d', kernel_size=7, stride=2,
                 padding=3, bias=use_bias, norm_layer=norm_layer,
-                sampling='none'
+                sampling='none', attention=False
             )
         ]

@@ -58,7 +59,7 @@ def __init__(self,
                     in_dims=in_dims, out_dims=out_dims,
                     conv_type='conv2d', kernel_size=3, stride=2,
                     padding=1, bias=use_bias, norm_layer=norm_layer,
-                    sampling='none'
+                    sampling='none', attention=attention
                 )
             )
         encoder.append(nn.AdaptiveAvgPool2d(1))
@@ -77,15 +78,15 @@ def forward(self, x):
         return levels, latent


-class MobileNetV2Comparator(nn.Module):
+class ComparatorMobileNetV2(nn.Module):

     def __init__(self,
                  levels=4,
                  width_mult=1.0,
                  norm_type='batch',
                  dropout=0.2
                  ):
-        super(MobileNetV2Comparator, self).__init__()
+        super(ComparatorMobileNetV2, self).__init__()

         assert norm_type in ['batch', 'instance', 'none']
         norm_layer = get_norm_layer(norm_type=norm_type)
@@ -111,10 +112,10 @@ def forward(self, x):
         return levels, latent


-class ShuffleNetV2Comparator(nn.Module):
+class ComparatorShuffleNetV2(nn.Module):

     def __init__(self, levels=4, width_mult=1.0):
-        super(ShuffleNetV2Comparator, self).__init__()
+        super(ComparatorShuffleNetV2, self).__init__()
         assert width_mult in [0.5, 1.0, 1.5, 2.0]

         if width_mult == 0.5:
@@ -148,21 +149,3 @@ def forward(self, x):
         latent = latent.mean([2, 3])  # globalpool
         levels = self.classify_head(latent)
         return levels, latent
-
-
-if __name__ == '__main__':
-
-    # model = MobileNetV2Comparator(
-    #     levels=4,
-    #     width_mult=1.0,
-    #     norm_type='batch',
-    #     dropout=0.2
-    # ).cuda()
-
-    model = MobileNetV2Comparator(levels=4, width_mult=1.0).cuda()
-
-    print(model)
-
-    x = torch.rand(2, 3, 1024, 1024).cuda()
-    levels, latent = model(x)
-    print(levels.size(), latent.size())
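The ad-hoc __main__ smoke test deleted above can still be reproduced against the renamed classes. Here is a minimal sketch (not part of this commit), with the toy input shape and .cuda() usage carried over from the deleted block; the libs.models.C import path is an assumption based on the file name.

# Hypothetical standalone smoke test for the renamed comparator.
import torch

from libs.models.C import ComparatorMobileNetV2  # import path assumed


if __name__ == '__main__':
    model = ComparatorMobileNetV2(levels=4, width_mult=1.0).cuda()
    print(model)

    # batch of 2 RGB 1024x1024 crops, same toy input as the deleted test
    x = torch.rand(2, 3, 1024, 1024).cuda()
    levels, latent = model(x)  # level logits and the pooled latent
    print(levels.size(), latent.size())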
libs/models/D.py (35 changes: 6 additions & 29 deletions)
@@ -61,25 +61,14 @@ def __init__(self, input_channels, init_channels=64,

         norm_layer = get_norm_layer(norm_type=norm_type)
         use_bias = False if norm_type == 'batch' else True
-        # if type(norm_layer) == functools.partial:
-        #     use_bias = norm_layer.func == nn.InstanceNorm2d
-        # else:
-        #     use_bias = norm_layer == nn.InstanceNorm2d
+        kw, padw = 4, 1

-        kw = 4
-        padw = 1
         sequence = [
-            # nn.Conv2d(
-            #     input_channels, init_channels,
-            #     kernel_size=kw, stride=2, padding=padw
-            # ),
-            # norm_layer(init_channels),
-            # nn.LeakyReLU(0.2, True)
             ConvNormAct(
                 in_dims=input_channels, out_dims=init_channels,
-                conv_type='conv2d', kernel_size=3, stride=2, padding=1,
+                conv_type='conv2d', kernel_size=kw, stride=2, padding=padw,
                 bias=use_bias, norm_layer=norm_layer, sampling='none',
-                attention=attention
+                attention=False
             )
         ]

@@ -91,15 +80,9 @@ def __init__(self, input_channels, init_channels=64,
             in_dims = init_channels * nf_mult_prev
             out_dims = init_channels * nf_mult
             sequence += [
-                # nn.Conv2d(
-                #     init_channels * nf_mult_prev, init_channels * nf_mult,
-                #     kernel_size=kw, stride=2, padding=padw, bias=use_bias
-                # ),
-                # norm_layer(init_channels * nf_mult),
-                # nn.LeakyReLU(0.2, True)
                 ConvNormAct(
                     in_dims=in_dims, out_dims=out_dims,
-                    conv_type='conv2d', kernel_size=3, stride=2, padding=1,
+                    conv_type='conv2d', kernel_size=kw, stride=2, padding=padw,
                     bias=use_bias, norm_layer=norm_layer, sampling='none',
                     attention=attention
                 )
@@ -110,23 +93,17 @@ def __init__(self, input_channels, init_channels=64,
         in_dims = init_channels * nf_mult_prev
         out_dims = init_channels * nf_mult
         sequence += [
-            # nn.Conv2d(
-            #     init_channels * nf_mult_prev, init_channels * nf_mult,
-            #     kernel_size=kw, stride=1, padding=padw, bias=use_bias
-            # ),
-            # norm_layer(init_channels * nf_mult),
-            # nn.LeakyReLU(0.2, True)
             ConvNormAct(
                 in_dims=in_dims, out_dims=out_dims,
-                conv_type='conv2d', kernel_size=3, stride=1, padding=1,
+                conv_type='conv2d', kernel_size=kw, stride=1, padding=padw,
                 bias=use_bias, norm_layer=norm_layer, sampling='none',
                 attention=attention
            )
         ]

         sequence += [
             nn.Conv2d(
-                init_channels * nf_mult, 1,
+                out_dims, 1,
                 kernel_size=kw, stride=1, padding=padw
             )
         ]
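All discriminator convolutions now share kernel_size=kw=4 with padding=padw=1, so each stride-2 block halves the spatial size and each stride-1 block shrinks it by one pixel. A quick sanity-check sketch follows (not part of this commit); the 256-pixel input and the assumption of three stride-2 layers are illustrative only, while the two trailing stride-1 convolutions come from this diff.

# Hypothetical sanity check of the spatial sizes produced by the 4x4 convs in D,
# using the usual conv arithmetic: out = (in + 2*pad - k) // stride + 1.
def conv_out(size, k=4, stride=2, pad=1):
    return (size + 2 * pad - k) // stride + 1


size = 256                      # example input resolution (assumption)
for stride in (2, 2, 2, 1, 1):  # three stride-2 blocks assumed, then the stride-1 convs above
    size = conv_out(size, stride=stride)
    print(size)                 # 128, 64, 32, 31, 30 -> a 30x30 patch map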
