Learning Day 21: CNN architectures

De Jun Huang · May 6, 2021

Evolution of CNNs

LeNet-5

  • 5 layers (2 conv + 3 fully connected)
  • Subsampling layers (instead of max pooling)

AlexNet

  • 8 layers (5 conv + 3 fully connected)
  • Max pooling
  • ReLU activation function
  • Dropout regularisation

VGG (VGG11, 16, 19..)

  • Stacking successive layers with smaller kernels (e.g. 3x3 instead of 7x7) runs faster without sacrificing much accuracy, as the stacked structure has fewer weights for the same receptive field (see the parameter-count sketch below)
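
A rough parameter count illustrates the point (a sketch with an illustrative channel count, not VGG’s exact configuration):

import torch
from torch import nn

c = 64  # illustrative channel count

# One 7x7 conv vs. a stack of three 3x3 convs: both cover a 7x7 receptive field
single_7x7 = nn.Conv2d(c, c, kernel_size=7, padding=3)
stacked_3x3 = nn.Sequential(
    nn.Conv2d(c, c, kernel_size=3, padding=1),
    nn.Conv2d(c, c, kernel_size=3, padding=1),
    nn.Conv2d(c, c, kernel_size=3, padding=1),
)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(single_7x7))   # 64*64*7*7 + 64 = 200768
print(n_params(stacked_3x3))  # 3 * (64*64*3*3 + 64) = 110784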

GoogLeNet

  • 22 layers
  • Uses multiple kernel sizes in parallel within a layer (the Inception module), as sketched below
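
A minimal sketch of the idea, with parallel branches concatenated along the channel dimension (simplified; the real Inception module also uses 1x1 convs to reduce channels before the larger kernels):

import torch
from torch import nn

class MiniInception(nn.Module):
    # Run several kernel sizes in parallel and concatenate the feature maps
    def __init__(self, ch_in, ch_branch):
        super().__init__()
        self.b1 = nn.Conv2d(ch_in, ch_branch, kernel_size=1)
        self.b3 = nn.Conv2d(ch_in, ch_branch, kernel_size=3, padding=1)
        self.b5 = nn.Conv2d(ch_in, ch_branch, kernel_size=5, padding=2)

    def forward(self, x):
        # Padding keeps every branch at the same spatial size, so concat works
        return torch.cat([self.b1(x), self.b3(x), self.b5(x)], dim=1)

blk = MiniInception(16, 8)
print(blk(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 24, 32, 32])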

ResNet (ResNet 34, 50, 101, 152..)

  • Beyond about 20 layers, stacking more layers doesn’t always give better accuracy, because gradients vanish in very deep networks
  • Skip/shortcut connections are introduced so that more layers can be stacked
  • The residual connection (the shortcut) lets the network learn to bypass redundant layers, so an over-deep network can automatically simplify itself towards a shallower one (e.g. VGG-like) if it has more layers than necessary; see the gradient sketch below
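
The gradient intuition in one line: for y = f(x) + x, dy/dx = f'(x) + 1, so the gradient through the block stays close to 1 even when f'(x) is tiny. A tiny autograd check (illustrative values only):

import torch

x = torch.tensor(2.0, requires_grad=True)
f = 1e-6 * x ** 2   # a branch whose local gradient is nearly zero (4e-6 here)
y = f + x           # residual connection: y = f(x) + x
y.backward()
print(x.grad)       # ~1.000004: the shortcut path contributes gradient 1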

Simplified ResNet18 implementation

  • Build a Residual block (ResBlk) as a unit first
  • Add ResBlk units to the model as needed
  • Use a self.extra layer (1x1 convolution) to transform the input channels to match the output channels
  • (The model doesn’t perform particularly well: the training loss keeps decreasing but test accuracy plateaus at about 62% after 30 epochs, so it is probably overfitting. Just for illustration purposes)

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import optim

# set up a resnet block first
class ResBlk(nn.Module):

    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # 1x1 conv on the shortcut path to match channel counts when needed
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # residual connection: add the (possibly transformed) input back
        out = self.extra(x) + out
        out = F.relu(out)
        return out
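
Continuing from the block above, a quick shape check of ResBlk (the 1x1 self.extra path kicks in because the channel counts differ):

blk = ResBlk(16, 32)
out = blk(torch.randn(2, 16, 32, 32))
print(out.shape)  # torch.Size([2, 32, 32, 32]): channels matched by the 1x1 conv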
# the main model
class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16)
        )

        self.blk1 = ResBlk(16, 16)
        self.blk2 = ResBlk(16, 32)
        # self.blk3 = ResBlk(128, 256)
        # self.blk4 = ResBlk(256, 512)

        # no downsampling anywhere, so features stay 32 channels x 32 x 32
        self.outlayer = nn.Linear(32*32*32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        # x = self.blk3(x)
        # x = self.blk4(x)

        # flatten to (batch, 32*32*32) before the linear classifier
        x = x.reshape(x.shape[0], -1)
        x = self.outlayer(x)

        return x

def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # sanity-check one batch
    x, label = next(iter(cifar_train))
    print(f"x shape: {x.shape}, label shape: {label.shape}")

    # fall back to CPU if no GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18().to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)

            pred = model(x)
            loss = criterion(pred, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"epoch: {epoch}, loss: {loss.item()}")

        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)

                pred = model(x)
                pred = pred.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.shape[0]

            acc = total_correct / total_num
            print(f"epoch: {epoch}, acc: {acc}")

if __name__ == '__main__':
    main()
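
With batch size 32 on CIFAR-10, the sanity-check print near the top of main() should show x shape: torch.Size([32, 3, 32, 32]) and label shape: torch.Size([32]).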

