# -*- coding: utf-8 -*-
"""convolutional_enhance_with_WideResNet.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1w74OijfDZ7diEUFZxV05vNTsprIdKXRm
"""
import os
import tarfile

import torch
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as tt
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url
from torchvision.utils import make_grid
# WideResNet22 - 22 convolutional layers: a model with residual blocks and batch normalization.
# A residual block adds the original input back to the output feature map obtained by passing the
# input through one or more convolutional layers.
# CNN training enhanced with ResNet-style connections and regularization techniques:
# 1. Data normalization - subtract the mean and divide by the standard deviation of each channel,
#    so that each channel has mean 0 and standard deviation 1. This prevents the values of any one
#    channel from disproportionately affecting losses and gradients during training (due to a higher
#    or wider range of values than the others).
# 2. Data augmentation
# 3. Residual connections
# 4. Batch normalization
# 5. Learning rate annealing
# 6. Weight decay
# 7. Gradient clipping
dataset_url = "http://files.fast.ai/data/examples/cifar10.tgz"
download_url(dataset_url, '.')
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    tar.extractall(path='./data')
data_dir = './data/cifar10'
print(os.listdir(data_dir))
print('')
classes = os.listdir(data_dir + "/train")
print(classes)
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
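# The hardcoded stats above are the widely used CIFAR-10 per-channel mean/std. As a sketch (not part
# of the original notebook), they could be recomputed from the raw training images like this - note
# that stacking all the images takes a few hundred MB of RAM:
#
#   raw_ds = ImageFolder(data_dir + '/train', tt.ToTensor())
#   pixels = torch.stack([img for img, _ in raw_ds])  # shape (N, 3, 32, 32)
#   mean = pixels.mean(dim=(0, 2, 3))                 # per-channel mean
#   std = pixels.std(dim=(0, 2, 3))                   # per-channel std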
train_tfms = tt.Compose([tt.RandomCrop(32, padding=4, padding_mode='reflect'),
                         tt.RandomHorizontalFlip(),
                         tt.ToTensor(),
                         tt.Normalize(*stats)])
valid_tfms = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])
train_ds = ImageFolder(data_dir + '/train', train_tfms)
valid_ds = ImageFolder(data_dir + '/test', valid_tfms)

batch_size = 256
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size, shuffle=False, num_workers=8, pin_memory=True)
# The num_workers argument leverages multiple CPU cores to load images in parallel.
# pin_memory=True places each fetched batch in page-locked (pinned) RAM, avoiding repeated memory
# allocation/deallocation and speeding up transfers to the GPU.
# Note: the validation loader does not need shuffling, so shuffle=False is used there.
import matplotlib.pyplot as plt
def show_batch(dl):
    for images, labels in dl:
        fig, ax = plt.subplots(figsize=(16, 16))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images[:100], 10).permute(1, 2, 0))
        break
show_batch(train_dl)
import torch.nn as nn
import torch.nn.functional as F
class SimpleResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu1(out)
        out = self.conv2(out)
        return self.relu2(out + x)  # skip connection: add the input back before the final ReLU
simple_resnet = SimpleResidualBlock()
for images, labels in train_dl:
    out = simple_resnet(images)
    print(out.shape)  # torch.Size([256, 3, 32, 32]) - same shape as the input
    break
# https://towardsdatascience.com/residual-blocks-building-blocks-of-resnet-fd90ca15d6ec
# After each convolutional layer, add a batch normalization layer.
def conv_2d(ni, nf, stride=1, ks=3):
    return nn.Conv2d(in_channels=ni, out_channels=nf, kernel_size=ks, stride=stride,
                     padding=ks//2, bias=False)

def bn_relu_conv(ni, nf):
    return nn.Sequential(nn.BatchNorm2d(ni),
                         nn.ReLU(inplace=True),
                         conv_2d(ni, nf))
class ResidualBlock(nn.Module):
    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, stride)
        self.conv2 = bn_relu_conv(nf, nf)
        self.shortcut = lambda x: x  # identity shortcut
        if ni != nf:
            # 1x1 convolution to match the number of channels on the shortcut path
            self.shortcut = conv_2d(ni, nf, stride, 1)

    def forward(self, x):
        x = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x)
        x = self.conv1(x)
        x = self.conv2(x) * 0.2  # scale the residual branch before adding the shortcut back
        return x.add_(r)
def make_group(N, ni, nf, stride):
    start = ResidualBlock(ni, nf, stride)
    rest = [ResidualBlock(nf, nf) for _ in range(1, N)]
    return [start] + rest

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)
class WideResNet(nn.Module):
    def __init__(self, n_groups, N, n_classes, k=1, n_start=16):
        super().__init__()
        # Increase channels to n_start using a conv layer
        layers = [conv_2d(3, n_start)]
        n_channels = [n_start]

        # Add groups of residual blocks (increase channels & downsample)
        for i in range(n_groups):
            n_channels.append(n_start * (2**i) * k)
            stride = 2 if i > 0 else 1
            layers += make_group(N, n_channels[i], n_channels[i+1], stride)

        # n_channels[-1] (rather than a hardcoded n_channels[3]) stays correct for any n_groups
        layers += [nn.BatchNorm2d(n_channels[-1]),
                   nn.ReLU(inplace=True),
                   nn.AdaptiveAvgPool2d(1),
                   Flatten(),
                   nn.Linear(n_channels[-1], n_classes)]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)

def wrn_22():
    # 3 groups of 3 blocks each, widening factor k=6
    return WideResNet(n_groups=3, N=3, n_classes=10, k=6)

model = wrn_22()
print(model)
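# As a quick sanity check (not in the original notebook), the model size can be inspected like this:
n_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {n_params:,}')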
for images, labels in train_dl:
    print('images.shape:', images.shape)
    out = model(images)
    print('out.shape:', out.shape)
    break
# Learning rate scheduling dynamically changes the learning rate while the model is being trained.
# Among many strategies, this notebook uses the 1-cycle policy: start with a low learning rate,
# increase it roughly linearly to a high value for about half of training, then bring it back down,
# finishing the last few iterations with a very low learning rate.
# 1-cycle: https://sgugger.github.io/the-1cycle-policy.html
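# A minimal sketch of the same schedule in plain PyTorch (illustration only - fastai's fit_one_cycle
# below handles this internally). torch.optim.lr_scheduler.OneCycleLR ramps the learning rate up and
# back down over the given number of steps:
#
#   import torch.optim as optim
#   optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)
#   scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-3,
#                                             steps_per_epoch=len(train_dl), epochs=9)
#   # inside the training loop, call scheduler.step() after each optimizer.step()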
# Install fastai (the `!` prefix only works inside a notebook; from a shell, run
# `pip install fastai --upgrade`)
# !pip install fastai --upgrade
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastai.metrics import accuracy

batch_size = 64
path = './data/cifar10'

# Recreate the DataLoaders with a smaller batch size and wrap them for fastai
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)
dls = DataLoaders(train_dl, valid_dl)

# Define the learner with the loss function and accuracy metric
learner = Learner(dls, model, loss_func=F.cross_entropy, metrics=[accuracy])  # , path=path)

# Gradient clipping limits gradient values to the range [-0.1, 0.1], preventing undesirably large
# changes to the parameters (weights & biases) caused by large gradient values.
# Note: `learner.clip = 0.1` is the fastai v1 API; fastai v2 configures clipping via a callback.
learner.clip = 0.1
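# A minimal sketch of the same effect in plain PyTorch (illustration only): after loss.backward()
# and before optimizer.step(), clamp every gradient element to [-0.1, 0.1]:
#
#   torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.1)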
# fastai also provides a learning rate finder, which tries a range of learning rates and helps you
# select a good one from the graph of loss vs. learning rate.
learner.lr_find()

# Train with the 1-cycle policy: 9 epochs, max learning rate 5e-3, weight decay 1e-4
learner.fit_one_cycle(9, 5e-3, wd=1e-4)
learner.recorder.plot_loss()

# Save the trained weights
learner.save('cifar10_trained_model')