-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify model.py
88 lines (74 loc) · 3.1 KB
/
classify model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from torch import nn
# 3D U-Net model and its building blocks
class pub(nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by (optional) BN and ReLU.

    Both convolutions use ``padding=0``, so every spatial dimension shrinks
    by 4 voxels in total (2 per convolution).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the second convolution.
        batch_norm: When true, insert ``nn.BatchNorm3d`` between each
            convolution and its ReLU.
    """

    def __init__(self, in_channels, out_channels, batch_norm=True):
        super(pub, self).__init__()
        # Channel counts must be integers: use floor division (true division
        # yields a float in Python 3, which nn.Conv3d rejects). Clamp to 1 so
        # that very narrow configurations still produce a valid layer.
        wider = in_channels if in_channels > out_channels else out_channels
        inter_channels = max(1, wider // 2)

        layers = [nn.Conv3d(in_channels, inter_channels, 3, stride=1, padding=0)]
        if batch_norm:
            layers.append(nn.BatchNorm3d(inter_channels))
        layers.append(nn.ReLU(True))

        layers.append(nn.Conv3d(inter_channels, out_channels, 3, stride=1, padding=0))
        if batch_norm:
            layers.append(nn.BatchNorm3d(out_channels))
        layers.append(nn.ReLU(True))

        self.pub = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        return self.pub(x)
class unet3dDown(nn.Module):
    """Encoder stage: 2x max-pooling followed by a ``pub`` convolution block.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution block.
        batch_norm: Forwarded to ``pub``.
    """

    def __init__(self, in_channels, out_channels, batch_norm=True):
        super().__init__()
        # Registration order (pub, then pool) is kept as in the original.
        self.pub = pub(in_channels, out_channels, batch_norm)
        self.pool = nn.MaxPool3d(2, stride=2)

    def forward(self, x):
        """Halve the spatial size of ``x`` by pooling, then convolve."""
        pooled = self.pool(x)
        return self.pub(pooled)
class unet3dUp(nn.Module):
    """Decoder stage: upsample, concatenate a centre-cropped skip tensor, convolve.

    The convolution block is built for ``in_channels + in_channels // 2``
    input channels, i.e. it presumes the skip tensor carries
    ``in_channels // 2`` channels.

    Args:
        in_channels: Channels of the tensor coming from the deeper stage.
        out_channels: Channels produced by this stage.
        batch_norm: Forwarded to ``pub``.
        sample: When true, upsample with nearest-neighbour ``nn.Upsample``;
            otherwise use a stride-2 ``nn.ConvTranspose3d``.
    """

    def __init__(self, in_channels, out_channels, batch_norm=True, sample=True):
        super(unet3dUp, self).__init__()
        # Floor division keeps the channel count an int (nn.Conv3d rejects floats).
        self.pub = pub(in_channels // 2 + in_channels, out_channels, batch_norm)
        if sample:
            self.sample = nn.Upsample(scale_factor=2, mode='nearest')
        else:
            self.sample = nn.ConvTranspose3d(in_channels, in_channels, 2, stride=2)

    def forward(self, x, x1):
        """Upsample ``x``, crop ``x1`` to match, concatenate on dim 1, convolve.

        Args:
            x: Tensor from the deeper stage.
            x1: Skip tensor from the encoder; each spatial dimension must be
                at least as large as that of the upsampled ``x``.

        Raises:
            ValueError: If ``x1`` is spatially smaller than the upsampled ``x``.
        """
        x = self.sample(x)

        target = tuple(x.shape[2:])
        source = tuple(x1.shape[2:])
        if any(s < t for s, t in zip(source, target)):
            raise ValueError(
                "skip tensor spatial size {} is smaller than upsampled size {}".format(
                    source, target
                )
            )

        # Centre-crop every spatial axis independently. Using explicit
        # start/stop bounds (instead of ``c:-c``) stays correct when the
        # sizes already match (offset 0) and when the difference is odd.
        crop = (slice(None), slice(None)) + tuple(
            slice((s - t) // 2, (s - t) // 2 + t) for s, t in zip(source, target)
        )
        x1 = x1[crop]

        x = torch.cat((x, x1), dim=1)
        x = self.pub(x)
        return x
class unet3d(nn.Module):
    """3D U-Net with three down-sampling and three up-sampling stages.

    All 3x3x3 convolutions in the ``pub`` blocks are unpadded, so the output
    is spatially smaller than the input. The final 1x1x1 convolution maps to
    ``class_nums`` channels and is followed by a sigmoid.

    Args:
        init_channels: Channels of the input volume.
        class_nums: Channels of the output map.
        batch_norm: Forwarded to every convolution block.
        sample: Forwarded to every ``unet3dUp`` stage (selects the
            upsampling operator).
    """

    def __init__(self, init_channels=1, class_nums=1, batch_norm=True, sample=True):
        super(unet3d, self).__init__()
        # Encoder path.
        self.down1 = pub(init_channels, 64, batch_norm)
        self.down2 = unet3dDown(64, 128, batch_norm)
        self.down3 = unet3dDown(128, 256, batch_norm)
        self.down4 = unet3dDown(256, 512, batch_norm)
        # Decoder path; each stage consumes the matching encoder output.
        self.up3 = unet3dUp(512, 256, batch_norm, sample)
        self.up2 = unet3dUp(256, 128, batch_norm, sample)
        self.up1 = unet3dUp(128, 64, batch_norm, sample)
        # Output head.
        self.con_last = nn.Conv3d(64, class_nums, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder/decoder and return sigmoid activations."""
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x = self.up3(x4, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)
        x = self.con_last(x)
        return self.sigmoid(x)

    def _initialize_weights(self):
        """Re-initialise Conv3d and BatchNorm3d parameters in place.

        Conv3d weights get Kaiming-uniform values and zero biases;
        BatchNorm3d gets unit weights and zero biases. This method is not
        invoked by ``__init__``; call it explicitly when wanted.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # ``kaiming_uniform`` (no trailing underscore) is deprecated.
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)