timm_convmixer.py
"""
As the name suggests this module is a copy of the convmixer module from
[timm](https://github.com/rwightman/pytorch-image-models) with a couple of
changes specially to allow converting the model. This updates the way the
model handles padding in a `nn.Conv2D` layer to allow it to be converted.
"""
from functools import reduce
from operator import __add__
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.registry import register_model
from .helpers import build_model_with_cfg


def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.96,
"interpolation": "bicubic",
"mean": IMAGENET_DEFAULT_MEAN,
"std": IMAGENET_DEFAULT_STD,
"classifier": "head",
"first_conv": "stem.0",
**kwargs,
}
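# For reference (illustrative, not part of the original module): calling _cfg()
# with no arguments yields exactly the defaults above, e.g.
# _cfg()["crop_pct"] == 0.96, and keyword arguments override or extend them,
# as in the entries below.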
default_cfgs = {
"convmixer_1536_20": _cfg(
url="https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar"
),
"convmixer_768_32": _cfg(
url="https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar"
),
"convmixer_1024_20_ks9_p14": _cfg(
url="https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar"
),
}


class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class Conv2dSamePadding(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build an explicit zero-padding layer equivalent to "same" padding:
        # for each kernel dimension k (width first, then height, as
        # nn.ZeroPad2d expects), pad (k - 1) // 2 before and k // 2 after.
        self.zero_pad_2d = nn.ZeroPad2d(
            reduce(
                __add__,
                [
                    (k // 2 + (k - 2 * (k // 2)) - 1, k // 2)
                    for k in self.kernel_size[::-1]
                ],
            )
        )

    def forward(self, input):
        # Pad explicitly, then run the convolution without any implicit padding.
        return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)
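# A quick sanity check (illustrative, not part of the original module): for the
# default 9x9 depthwise kernel the explicit padding is 4 pixels on every side,
# matching what stride-1 "same" padding would produce:
#
#     conv = Conv2dSamePadding(8, 8, kernel_size=9, groups=8)
#     conv.zero_pad_2d.padding  # -> (4, 4, 4, 4)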


class ConvMixer(nn.Module):
def __init__(
self,
dim,
depth,
kernel_size=9,
patch_size=7,
in_chans=3,
num_classes=1000,
activation=nn.GELU,
**kwargs
):
super().__init__()
self.num_classes = num_classes
self.num_features = dim
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
        # Patch embedding stem: a non-overlapping patch_size x patch_size conv.
        self.stem = nn.Sequential(
nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size),
activation(),
nn.BatchNorm2d(dim),
)
        # depth ConvMixer blocks: a residual depthwise convolution followed by
        # a pointwise (1x1) convolution, each with activation and BatchNorm.
        self.blocks = nn.Sequential(
*[
nn.Sequential(
Residual(
nn.Sequential(
Conv2dSamePadding(dim, dim, kernel_size, groups=dim),
activation(),
nn.BatchNorm2d(dim),
)
),
nn.Conv2d(dim, dim, kernel_size=1),
activation(),
nn.BatchNorm2d(dim),
)
                for _ in range(depth)
]
)
        # Global average pooling followed by a flatten to (batch, dim).
        self.pooling = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = (
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity()
        )

    def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.pooling(x)
        return x

    def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
        return x


def _create_convmixer(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs
    )


@register_model
def convmixer_1536_20(pretrained=False, **kwargs):
model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
return _create_convmixer("convmixer_1536_20", pretrained, **model_args)
@register_model
def convmixer_768_32(pretrained=False, **kwargs):
model_args = dict(
dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs
)
return _create_convmixer("convmixer_768_32", pretrained, **model_args)
@register_model
def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs):
model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
return _create_convmixer("convmixer_1024_20_ks9_p14", pretrained, **model_args)