Revisiting Convolutional Neural Network (CNN) Architectures: PlainNet & ResNet Using PyTorch

In this blog post, I use PyTorch to rebuild a classic vision model, ResNet-50, to get more hands-on practice with real-world computer vision applications.

Preparation

Please make sure you have the following packages installed:

torch
torchvision
plotly
IPython

Import the libraries and check GPU availability. The code picks CUDA if it is available, falls back to MPS on Apple Silicon Macs, and otherwise uses the CPU. The nvidia-smi check at the end uses IPython shell syntax, so it only runs inside a notebook such as Colab or Jupyter.

# ============================
# Section 1: General Utilities
# ============================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import plotly.graph_objects as go
from IPython.display import clear_output, display

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

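# IPython shell escape (Colab/Jupyter only): capture nvidia-smi output as a list of lines
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# 'failed' = driver present but no GPU attached; 'not found' = no NVIDIA tooling (e.g. a Mac)
if 'failed' in gpu_info or 'not found' in gpu_info:
    print('Not connected to a GPU')
else:
    print(gpu_info)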

ResNet-50 is probably overkill for CIFAR-10, but my goal with this exercise is to get more familiar with PyTorch and eventually try some new things; I didn't want to train a ResNet-32 again.

On the Colab instance used here, the cell above printed the following nvidia-smi output:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   69C    P8             14W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Helper functions to store the metadata

  • An ExperimentLogger class that stores the metadata of the training process, including the loss and accuracy for each epoch, so we can visualize training later.

  • A function that prints an overview of the model structure.

The logger doubles as a mini TensorBoard replacement built on Plotly that plots live in Colab, with:

  • 📈 Dual-axis live plotting (loss/accuracy on the left axis, learning rate on the right)

  • 📁 Auto-export to HTML (for website embedding)

  • ✅ Minimal dependencies

  • 😎 Full control, no black-boxing

import plotly.io as pio
import plotly.graph_objects as go

# Mount Google Drive so the figures can be saved there (Colab only)
from google.colab import drive
drive.mount('/content/drive')

class ExperimentLogger:
    def __init__(self):
        self.train_losses = []
        self.train_accuracies = []
        self.test_losses = []
        self.test_accuracies = []
        self.lrs = []

    def log_learning_rate(self, lr):
        self.lrs.append(lr)

    def log_train_loss(self, loss):
        self.train_losses.append(loss)

    def log_train_accuracy(self, accuracy):
        self.train_accuracies.append(accuracy)

    def log_test_loss(self, loss):
        self.test_losses.append(loss)

    def log_test_accuracy(self, accuracy):
        self.test_accuracies.append(accuracy)

    def log_train(self, loss, accuracy):
        self.log_train_loss(loss)
        self.log_train_accuracy(accuracy)

    def log_test(self, loss, accuracy):
        self.log_test_loss(loss)
        self.log_test_accuracy(accuracy)

    def live_plot(self, epochs, label):
        clear_output(wait=True)
        fig = go.Figure()
        # Epochs start at 1 in run_experiment, so number the x-axis from 1
        x = list(range(1, len(self.train_losses) + 1))

        # Primary y-axis (losses/accuracy)
        fig.add_trace(go.Scatter(x=x, y=self.train_losses, mode='lines+markers', name='Training Loss'))
        fig.add_trace(go.Scatter(x=x, y=self.train_accuracies, mode='lines+markers', name='Training Accuracy'))
        fig.add_trace(go.Scatter(x=x, y=self.test_losses, mode='lines+markers', name='Test Loss'))
        fig.add_trace(go.Scatter(x=x, y=self.test_accuracies, mode='lines+markers', name='Test Accuracy'))

        # Secondary y-axis (learning rate)
        if self.lrs:
            fig.add_trace(go.Scatter(x=x, y=self.lrs, mode='lines+markers', name='Learning Rate', yaxis='y2'))

        fig.update_layout(
            title=f'{label} - Learning Curves',
            xaxis_title='Epoch',
            yaxis=dict(title='Loss / Accuracy'),
            yaxis2=dict(
                title='Learning Rate',
                overlaying='y',
                side='right',
                showgrid=False
            ),
            width=900,
            height=500,
            legend_title='Metric'
        )

        display(fig)

        # Also export the figure as standalone HTML to Google Drive (path assumes the Drive mount above)
        pio.write_html(fig, file='/content/drive/MyDrive/Colab Notebooks/ResNetPytorch.html', auto_open=False)


    def early_stopping_triggered(self, patience=5, delta=1e-4):
        if len(self.test_accuracies) < patience:
            return False
        recent = self.test_accuracies[-patience:]
        improvement = max(recent) - min(recent)
        return improvement < delta

def show_model_structure(model, name):
    print(f"\n{name} Model Structure:")
    print(model)

Data Preparation

Let's use the well-known CIFAR-10 dataset for this exercise. We use torchvision to download and prepare it; CIFAR-10 comes with predefined training (50,000 images) and test (10,000 images) splits. For training, we apply the following data augmentations:

  • Random horizontal flip

    • Why: Many real-world objects (like animals, vehicles, etc.) are symmetric or can appear flipped in images.
    • Effect: It increases invariance to horizontal orientation, helping the model avoid overfitting to the left/right arrangement of features.
    • A cheap and effective performance boost.
  • Random cropping (with padding)

    • Why: Simulates the effect of a subject appearing at slightly different locations in the frame.
    • Effect: Helps the model learn positional robustness — i.e., not overly focusing on centered or fixed-position features.
  • (For ImageNet) Scale jittering (resizing + random crop)

    • Why: Real-world images vary in zoom and object size.
    • Effect: Teaches the model to recognize objects at multiple scales.
    • Often implemented with RandomResizedCrop(), which crops a region of random size and aspect ratio and then resizes it to the target size.
    • Especially important for large-scale and varied datasets like ImageNet.
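To make the first two augmentations concrete, here is a minimal hand-rolled sketch of what RandomCrop(32, padding=4) followed by RandomHorizontalFlip() does to one image. The helper pad_crop_flip is hypothetical and works on a (C, H, W) tensor, whereas the torchvision transforms used in this post run on PIL images before ToTensor(); in real training code, stick with the torchvision transforms.

```python
import torch
import torch.nn.functional as F


def pad_crop_flip(img, size=32, padding=4, generator=None):
    """Hypothetical sketch of pad + random crop + random horizontal flip on a (C, H, W) tensor."""
    # Zero-pad `padding` pixels on each side (torchvision's RandomCrop also fills with 0 by default)
    padded = F.pad(img, (padding, padding, padding, padding))
    _, h, w = padded.shape
    # Pick a random top-left corner for the size x size window
    top = torch.randint(0, h - size + 1, (1,), generator=generator).item()
    left = torch.randint(0, w - size + 1, (1,), generator=generator).item()
    out = padded[:, top:top + size, left:left + size]
    # Flip left/right with probability 0.5 (dim 2 is the width axis)
    if torch.rand(1, generator=generator).item() < 0.5:
        out = torch.flip(out, dims=[2])
    return out


g = torch.Generator().manual_seed(0)
x = torch.rand(3, 32, 32)
y = pad_crop_flip(x, generator=g)
print(y.shape)  # torch.Size([3, 32, 32])
```

The output keeps the 32×32 size; only the content shifts by up to 4 pixels and is possibly mirrored, which is exactly the positional and left/right robustness described above.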

For most practical cases, including ours, basic augmentations like these are not overkill; in fact, they are standard practice when training models like ResNet on CIFAR-10.

For CIFAR-10, the standard recipe (following the original ResNet paper) is:

  • Random crop with 4-pixel padding
  • Random horizontal flip
  • (Optional) Cutout or Mixup, introduced in later works rather than the original paper

These are lightweight and significantly boost generalization, especially when you're training a deep model like ResNet-50 on a relatively small dataset (CIFAR-10 has only 50k training images).

Augmentation does become overkill when:

  • You stack too many complex augmentations (e.g., AutoAugment, RandAugment, color jitter) on a simple baseline.
  • Your model is already very small or underfitting; augmentation won't help and may even hurt.
# ============================
# Section 2: Dataset Loading
# ============================

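from torchvision.transforms import AutoAugment, AutoAugmentPolicy

# Training transforms: the paper's crop + flip, plus AutoAugment's CIFAR-10 policy on top
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),            # zero-pad 4 px per side, then random 32x32 crop
    transforms.RandomHorizontalFlip(),               # flip left/right with probability 0.5
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),   # learned augmentation policy for CIFAR-10
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # per-channel mean (R, G, B)
                         (0.2023, 0.1994, 0.2010)),  # per-channel std (R, G, B)
])

# Test transforms: no augmentation, only conversion and the same normalization
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])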

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
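The Normalize constants above are hard-coded per-channel statistics of the CIFAR-10 training set. If you want to recompute them from your own copy, here is a minimal sketch; channel_mean_std is a hypothetical helper that pools every pixel of each channel. Other conventions (for example, averaging per-image standard deviations) give slightly different std values, which may be why different codebases quote slightly different constants.

```python
import torch


def channel_mean_std(images):
    """Per-channel mean and std of an (N, C, H, W) float tensor with values in [0, 1]."""
    # Reduce over batch, height and width, keeping only the channel dimension
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std


# Assumed usage with the real training set (about 600 MB as float32):
#   raw = datasets.CIFAR10(root='./data', train=True, download=True)
#   images = torch.from_numpy(raw.data).float().div(255).permute(0, 3, 1, 2)  # NHWC uint8 -> NCHW float
#   print(channel_mean_std(images))

# Runnable stand-in with random data
fake = torch.rand(8, 3, 32, 32)
mean, std = channel_mean_std(fake)
print(mean.shape, std.shape)  # torch.Size([3]) torch.Size([3])
```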

Model Definition

To create the CNN architecture, you can take torchvision's ResNet-50 (randomly initialized, not pretrained) and replace its final layer with a 10-class head:

# ============================
# Section 3: CNN Architectures
# ============================


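def get_resnet50_for_cifar10():
    # Randomly initialized torchvision ResNet-50 (weights=None replaces the deprecated pretrained=False)
    model = models.resnet50(weights=None)
    # Swap the 1000-class ImageNet head for a 10-class CIFAR-10 head
    model.fc = nn.Linear(model.fc.in_features, 10)
    return model

# Instantiate the model (the optimizer is built later, inside run_experiment)
model = get_resnet50_for_cifar10()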
You can also implement ResNet-50 yourself; here is the detailed implementation.

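Bottleneck block: ResNet-50 uses a bottleneck design to reduce computation while maintaining performance. A 1×1 convolution first reduces the channel count, a 3×3 convolution operates on the reduced representation, and a second 1×1 convolution expands the channels back by a factor of 4 (expansion = 4). The block's input is then added through the shortcut, using a 1×1 projection (downsample) whenever the stride or channel count changes.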

import torch
import torch.nn as nn


# ---------------------------
# Step 1: Define Bottleneck Block
# ---------------------------
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        out = self.relu(out)

        return out
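To put a number on "reduce computation", here is a quick weight count, assuming a 256-channel input (the width a layer1 block outputs) and bias-free convolutions, with BatchNorm ignored. It compares the bottleneck above with a hypothetical block of two plain 3×3 convolutions kept at the full 256 channels. In a stride-1 block every convolution runs at the same spatial resolution, so multiply-adds per output position scale the same way.

```python
def conv_weights(in_ch, out_ch, k):
    """Number of weights in a bias-free k x k convolution."""
    return in_ch * out_ch * k * k


# Bottleneck on a 256-channel input: 1x1 reduce -> 3x3 -> 1x1 expand (expansion = 4)
width = 64
bottleneck = (conv_weights(256, width, 1)
              + conv_weights(width, width, 3)
              + conv_weights(width, width * 4, 1))

# Hypothetical alternative: two plain 3x3 convolutions at the full 256 channels
plain = 2 * conv_weights(256, 256, 3)

print(bottleneck)                    # 69632
print(plain)                         # 1179648
print(round(plain / bottleneck, 1))  # 16.9
```

Most of the saving comes from running the expensive 3×3 convolution on 64 channels instead of 256.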

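ResNet-50 architecture: The model consists of an initial convolutional layer, four stages of bottleneck blocks (layer1 to layer4), global average pooling, and a fully connected layer for classification. Because CIFAR-10 images are only 32×32, the stem here is a 3×3, stride-1 convolution without max pooling, instead of the 7×7, stride-2 convolution plus max pooling used for ImageNet.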


# ---------------------------
# Step 2: Define ResNet Framework
# ---------------------------
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # CIFAR-10: no downsample
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)  # conv2_x
        x = self.layer2(x)  # conv3_x
        x = self.layer3(x)  # conv4_x
        x = self.layer4(x)  # conv5_x
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
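As a sanity check on _make_layer, the pure-Python sketch below replays its channel bookkeeping for [3, 4, 6, 3]; trace_blocks is a hypothetical helper, not part of the model code. It shows which blocks need a projection shortcut and where the "50" in ResNet-50 comes from: 16 bottleneck blocks × 3 convolutions, plus the stem convolution and the final fully connected layer (the 1×1 projection convolutions are conventionally not counted).

```python
EXPANSION = 4  # mirrors Bottleneck.expansion


def trace_blocks(layers, widths=(64, 128, 256, 512), strides=(1, 2, 2, 2), in_channels=64):
    """Replay _make_layer's channel bookkeeping (hypothetical helper).

    Returns one tuple per block: (stage, block, in_ch, out_ch, stride, uses_projection).
    """
    rows = []
    for stage, (n_blocks, width, stride) in enumerate(zip(layers, widths, strides), start=1):
        out_channels = width * EXPANSION
        # First block of a stage: projection shortcut if stride or channel count changes
        uses_projection = stride != 1 or in_channels != out_channels
        rows.append((stage, 1, in_channels, out_channels, stride, uses_projection))
        in_channels = out_channels
        # Remaining blocks: stride 1, channels unchanged, identity shortcut
        for block in range(2, n_blocks + 1):
            rows.append((stage, block, in_channels, out_channels, 1, False))
    return rows


rows = trace_blocks([3, 4, 6, 3])
for row in rows:
    print(row)

print(len(rows))                # 16 bottleneck blocks
print(sum(r[5] for r in rows))  # 4 projection shortcuts, one per stage
print(3 * len(rows) + 2)        # 50 = 16 blocks x 3 convs + stem conv + fc
```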

Then create the model through a small factory function:

# ---------------------------
# Step 3: Factory Function
# ---------------------------

def get_resnet50_for_cifar10():
    # Custom ResNet-50 (replaces the torchvision-based version defined earlier):
    # Bottleneck blocks in a [3, 4, 6, 3] layout with a 10-class head
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)

# Instantiate the model (the optimizer is built later, inside run_experiment)
model = get_resnet50_for_cifar10()

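Here, [3, 4, 6, 3] is the number of bottleneck blocks in each stage. The feature-map sizes differ between the original ImageNet setup (224×224 input, 7×7 stride-2 stem plus max pooling) and this CIFAR-10 variant (32×32 input, 3×3 stride-1 stem, no max pooling):

layer group   feature map (ImageNet)   feature map (CIFAR-10, this model)   output channels   bottleneck blocks
layer1        56 × 56                  32 × 32                              256               3
layer2        28 × 28                  16 × 16                              512               4
layer3        14 × 14                  8 × 8                                1024              6
layer4        7 × 7                    4 × 4                                2048              3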
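The feature-map sizes in the table follow from the standard output-size formula, floor((size + 2·padding − kernel) / stride) + 1. The sketch below applies it to both stems; stage_sizes is a hypothetical helper, and the third call shows what happens when the unmodified ImageNet stem sees a 32×32 CIFAR-10 image.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a conv/pool layer (PyTorch floors the division)."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(size, imagenet_stem):
    """Feature-map side length after layer1..layer4 (hypothetical helper)."""
    if imagenet_stem:
        size = conv_out(size, kernel=7, stride=2, padding=3)  # 7x7 stride-2 stem conv
        size = conv_out(size, kernel=3, stride=2, padding=1)  # 3x3 stride-2 max pool
    else:
        size = conv_out(size, kernel=3, stride=1, padding=1)  # CIFAR stem: 3x3 stride-1 conv
    sizes = []
    for stride in (1, 2, 2, 2):  # layer1..layer4
        # The first block's 3x3 conv (padding 1) carries the stage stride
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes.append(size)
    return sizes


print(stage_sizes(224, imagenet_stem=True))  # [56, 28, 14, 7]
print(stage_sizes(32, imagenet_stem=False))  # [32, 16, 8, 4]
print(stage_sizes(32, imagenet_stem=True))   # [8, 4, 2, 1]
```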

Time to train 🏋️‍♂️

  • Plotly is used to visualize the training process. The live_plot method is called at the end of each epoch to update the plot with the latest training and test losses and accuracies, so we can monitor training in real time!

How to train the model

Let's first determine what the train and test functions need:

  • A model to train (or evaluate)
  • Data, provided by a DataLoader (the loader argument)
  • An optimizer to update the model's weights (training only; evaluation does not use one)
  • The epoch index, used for logging (the total number of epochs is set by the experiment loop)
  • Finally, the logger object defined earlier to record the training process
# ============================
# Section 4: Model Training & Evaluation
# ============================

def train(model, loader, optimizer, epoch, logger):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)

    avg_loss = total_loss / len(loader)
    accuracy = correct / total
    print(f"Train Epoch: {epoch}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    logger.log_train(avg_loss, accuracy)

def test(model, loader, logger):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            total_loss += loss.item()

            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)

    avg_loss = total_loss / len(loader)
    acc = correct / total
    print(f"Test Accuracy: {acc:.4f}, Loss: {avg_loss:.4f}")
    logger.log_test(avg_loss, acc)

Gather the training process in one place

We can gather all of this into one function, treating the entire process as an experiment:

# ============================
# Section 5: Run Experiments
# ============================

import copy  # for deep copying model state_dict

def run_experiment(name, model, train_dataset, test_dataset, epochs,
                   optimizer_class=optim.SGD, optimizer_kwargs=None,
                   optimizer=None, lr_scheduler=None, 
                   batch_size=64, val_batch_size=64,
                   patience=None, delta=1e-4, 
                   checkpoint_path=None  # path to save the best model
                   ):
    
    best_model_state = None
    model = model.to(device)
    show_model_structure(model, name)

    logger = ExperimentLogger()

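    # Use the provided optimizer or construct one.
    # If you pass an lr_scheduler, also pass the optimizer it wraps:
    # a scheduler cannot control an optimizer created here.
    if optimizer is None:
        optimizer_kwargs = optimizer_kwargs or {"lr": 0.1, "momentum": 0.9}
        optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)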

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=val_batch_size, shuffle=False)

    best_accuracy = 0
    epochs_no_improve = 0

    for epoch in range(1, epochs + 1):
        train(model, train_loader, optimizer, epoch, logger)
        test(model, test_loader, logger)

        # Log the LR used during this epoch, then advance the scheduler
        current_lr = optimizer.param_groups[0]['lr']
        logger.log_learning_rate(current_lr)

        if lr_scheduler is not None:
            lr_scheduler.step()

        # Plot before the early-stopping check so the final epoch is always shown
        logger.live_plot(epoch, name)

        if logger.test_accuracies[-1] > best_accuracy + delta:
            best_accuracy = logger.test_accuracies[-1]
            best_model_state = copy.deepcopy(model.state_dict())  # 💾 store best model
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if patience and epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs)")
                break

    if checkpoint_path and best_model_state is not None:
        torch.save(best_model_state, checkpoint_path)
        print(f"✅ Saved best model to: {checkpoint_path}")

    # Note: the returned model holds the last epoch's weights; the best ones are in the checkpoint
    return model, logger

In the original paper (He et al., "Deep Residual Learning for Image Recognition"), the authors trained ResNet-50 on the ImageNet dataset with the following settings:

  • Optimizer: Stochastic Gradient Descent (SGD) with momentum
  • Momentum: 0.9
  • Weight Decay: 1e-4
  • Batch Size: 256
  • Initial Learning Rate: 0.1
  • Learning Rate Schedule: The lr was divided by 10 when the validation error plateaued.
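Here is a sketch of these settings in PyTorch, using a stand-in nn.Linear model, one dummy batch, and made-up validation errors (all assumptions for illustration). ReduceLROnPlateau divides the LR by 10 once the monitored error stops improving for patience epochs. Its step() takes the metric as an argument, so run_experiment's plain lr_scheduler.step() call would need to pass one (for example 1 - logger.test_accuracies[-1]); epoch-based schedulers such as MultiStepLR work with step() as written. Also note that run_experiment's default optimizer has no weight decay, so pass optimizer_kwargs={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4} to match the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Stand-in model and batch (assumptions); in the post this would be ResNet-50 and CIFAR-10
model = nn.Linear(4, 2)
inputs, targets = torch.randn(8, 4), torch.randint(0, 2, (8,))

# Paper's ImageNet settings: SGD, lr 0.1, momentum 0.9, weight decay 1e-4
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# Divide the LR by 10 after more than `patience` epochs without improvement in the error
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# Made-up validation errors: one improvement, then a plateau
for epoch, val_error in enumerate([0.50, 0.40, 0.40, 0.40, 0.40], start=1):
    # One dummy training step stands in for a full training epoch
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()

    scheduler.step(val_error)  # ReduceLROnPlateau needs the monitored metric
    print(epoch, val_error, optimizer.param_groups[0]['lr'])
# The LR stays at 0.1 for epochs 1-4 and drops to about 0.01 after epoch 5
```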

Here’s the interactive Plotly visualization of the training results (without data augmentation):

And here is the visualization of the training results (with data augmentation):