Training ResNet Using PyTorch
Revisiting Convolutional Neural Network (CNN) Architectures: PlainNet & ResNet Using PyTorch
In this blog post I use PyTorch to rebuild a classic vision model, ResNet-50, to get more hands-on practice with a real-world computer vision application.
Preparation
Please make sure you have the following packages installed:
torch
torchvision
plotly
IPython
Import the libraries and check GPU availability. The code uses the CUDA device if one is available; otherwise, on a Mac with MPS support, it falls back to the MPS device, and finally to the CPU.
# ============================
# Section 1: General Utilities
# ============================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import plotly.graph_objects as go
from IPython.display import clear_output, display
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# `!nvidia-smi` is IPython/Colab shell syntax; it does not run in a plain Python script
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)
ResNet-50 is probably overkill for CIFAR-10, but the purpose of this exercise is to get more familiar with PyTorch and eventually try some new things, rather than train a ResNet-32 again. The `nvidia-smi` output on the Colab runtime used here:
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 69C P8 14W / 70W | 2MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Helper functions to store the metadata
- Create a logger class to store the metadata of the training process, including the loss and accuracy for each epoch. This will help us visualize the training process later.
- Also a function to print an overview of the model structure.
Here is a mini TensorBoard replacement with Plotly, live in Colab, with:
- 📈 Dual-axis live plotting (losses + learning rate)
- 📁 Auto-export to HTML (for website embedding)
- ✅ Minimal dependencies
- 😎 Full control, no black-boxing
import plotly.io as pio
import plotly.graph_objects as go
# Mount Google Drive to save figures
from google.colab import drive
drive.mount('/content/drive')

class ExperimentLogger:
    def __init__(self):
        self.train_losses = []
        self.train_accuracies = []
        self.test_losses = []
        self.test_accuracies = []
        self.lrs = []

    def log_learning_rate(self, lr):
        self.lrs.append(lr)

    def log_train_loss(self, loss):
        self.train_losses.append(loss)

    def log_train_accuracy(self, accuracy):
        self.train_accuracies.append(accuracy)

    def log_test_loss(self, loss):
        self.test_losses.append(loss)

    def log_test_accuracy(self, accuracy):
        self.test_accuracies.append(accuracy)

    def log_train(self, loss, accuracy):
        self.log_train_loss(loss)
        self.log_train_accuracy(accuracy)

    def log_test(self, loss, accuracy):
        self.log_test_loss(loss)
        self.log_test_accuracy(accuracy)

    def live_plot(self, epochs, label):
        clear_output(wait=True)
        fig = go.Figure()
        # Primary y-axis (losses/accuracy)
        fig.add_trace(go.Scatter(y=self.train_losses, mode='lines+markers', name='Training Loss'))
        fig.add_trace(go.Scatter(y=self.train_accuracies, mode='lines+markers', name='Training Accuracy'))
        fig.add_trace(go.Scatter(y=self.test_losses, mode='lines+markers', name='Test Loss'))
        fig.add_trace(go.Scatter(y=self.test_accuracies, mode='lines+markers', name='Test Accuracy'))
        # Secondary y-axis (learning rate)
        if self.lrs:
            fig.add_trace(go.Scatter(y=self.lrs, mode='lines+markers', name='Learning Rate', yaxis='y2'))
        fig.update_layout(
            title=f'{label} - Learning Curves',
            xaxis_title='Epoch',
            yaxis=dict(title='Loss / Accuracy'),
            yaxis2=dict(
                title='Learning Rate',
                overlaying='y',
                side='right',
                showgrid=False
            ),
            width=900,
            height=500,
            legend_title='Metric'
        )
        display(fig)
        pio.write_html(fig, file='/content/drive/MyDrive/Colab Notebooks/ResNetPytorch.html', auto_open=False)

    def early_stopping_triggered(self, patience=5, delta=1e-4):
        # Plateau check: True when the last `patience` test accuracies vary by less than `delta`
        if len(self.test_accuracies) < patience:
            return False
        recent = self.test_accuracies[-patience:]
        improvement = max(recent) - min(recent)
        return improvement < delta

def show_model_structure(model, name):
    print(f"\n{name} Model Structure:")
    print(model)
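To make the plateau rule in `early_stopping_triggered` concrete, here is a minimal standalone sketch of the same check on plain lists. The function name `plateaued` and the sample accuracy histories are illustrative assumptions, not part of the logger above.

```python
def plateaued(history, patience=5, delta=1e-4):
    """Return True when the last `patience` values vary by less than `delta`."""
    if len(history) < patience:
        return False  # not enough epochs yet to judge
    recent = history[-patience:]
    return max(recent) - min(recent) < delta


# Hypothetical test accuracies, one value per epoch.
still_improving = [0.60, 0.70, 0.78, 0.83, 0.86]
flat = [0.60, 0.70, 0.80, 0.90, 0.90, 0.90, 0.90, 0.90]

print(plateaued(still_improving))  # spread of the last 5 values is far above delta
print(plateaued(flat))             # last 5 values are identical
```

Because the rule looks at the spread of recent values rather than at the best value so far, it only fires when accuracy has gone flat. The `run_experiment` function later in this post uses a best-so-far counter instead.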
Data Preparation
Let’s use the well-known CIFAR-10 dataset for this exercise. We will use the torchvision library to download and prepare the dataset. The dataset is split into training and test sets.
- Random horizontal flip
  - Why: Many real-world objects (like animals, vehicles, etc.) are symmetric or can appear flipped in images.
  - Effect: It increases invariance to orientation, helping the model avoid overfitting to the left/right arrangement of features.
  - Cheap and effective boost to performance.
- Random cropping (with padding)
  - Why: Simulates the effect of a subject appearing at slightly different locations in the frame.
  - Effect: Helps the model learn positional robustness, i.e., not overly focusing on centered or fixed-position features.
- (For ImageNet) Scale jittering (resizing + random crop)
  - Why: Real-world images vary in zoom and size.
  - Effect: Teaches the model to recognize objects at multiple scales.
  - Often implemented as RandomResizedCrop(), which resizes to random scales/aspect ratios before cropping.
  - Especially important for large-scale and varied datasets like ImageNet.
In our case, as in most practical cases, basic augmentations like these are not overkill; in fact, they are standard practice when training models like ResNet on CIFAR-10.
In the original ResNet paper, even for CIFAR-10, they used:
- Random crop with 4-pixel padding
- Random horizontal flip
- (Optional) Cutout or Mixup in later works
These are lightweight and significantly boost generalization, especially when you’re training a deep model like ResNet-50 on a relatively small dataset (CIFAR-10 has only 50k training images).
But overkill happens when:
- You stack too many complex augmentations (e.g., AutoAugment, RandAugment, color jitter, etc.) on a simple baseline.
- Your model is already very small or underfitting; augmentations won’t help and may hurt.
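To show what the two basic augmentations do to a single image, here is a minimal sketch written directly on tensors. It is an illustration only: the helper name `random_crop_flip` is hypothetical, the input is random noise standing in for a CIFAR-10 image, and in practice the torchvision transforms would be used instead.

```python
import torch
import torch.nn.functional as F


def random_crop_flip(img, padding=4, generator=None):
    """Pad a CxHxW image, crop a random HxW window, then flip with p=0.5."""
    _, h, w = img.shape
    # Zero-pad left, right, top, and bottom by `padding` pixels.
    padded = F.pad(img, (padding, padding, padding, padding))
    # Pick the top-left corner of the crop window uniformly at random.
    top = int(torch.randint(0, 2 * padding + 1, (1,), generator=generator))
    left = int(torch.randint(0, 2 * padding + 1, (1,), generator=generator))
    out = padded[:, top:top + h, left:left + w]
    # Mirror along the width axis half of the time.
    if torch.rand(1, generator=generator).item() < 0.5:
        out = torch.flip(out, dims=[2])
    return out


g = torch.Generator().manual_seed(0)
image = torch.rand(3, 32, 32, generator=g)  # stand-in for one CIFAR-10 image
augmented = random_crop_flip(image, generator=g)
print(augmented.shape)  # torch.Size([3, 32, 32])
```

The output keeps the original 32 x 32 size; only the position of the content shifts by up to 4 pixels in each direction, with zeros filling the exposed border.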
# ============================
# Section 2: Dataset Loading
# ============================
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
Model Definition
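Create the CNN architecture. torchvision ships a ResNet-50 implementation; here it is instantiated without pretrained weights, and its final layer is replaced with a 10-class head: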
# ============================
# Section 3: CNN Architectures
# ============================
def get_resnet50_for_cifar10():
    # weights=None replaces the deprecated pretrained=False (torchvision >= 0.13)
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, 10)
    return model
# Build the model
model = get_resnet50_for_cifar10()
But you can also implement it yourself; the detailed implementation follows.
Bottleneck block: The ResNet-50 architecture uses a bottleneck design (a 1x1, a 3x3, and a 1x1 convolution per block) to reduce computation while maintaining performance.
import torch
import torch.nn as nn
# ---------------------------
# Step 1: Define Bottleneck Block
# ---------------------------
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        out = self.relu(out)
        return out
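As a rough illustration of why the bottleneck reduces computation, the sketch below counts convolution weights for one bottleneck block with 256 input channels and compares it with two plain 3x3 convolutions at the same width. The helper `conv_params` and the chosen channel sizes are illustrative assumptions; batch norm parameters are ignored.

```python
def conv_params(c_in, c_out, k):
    """Number of weights in a k x k convolution without bias."""
    return c_in * c_out * k * k


# Bottleneck: 1x1 reduce (256 -> 64), 3x3 (64 -> 64), 1x1 expand (64 -> 256).
bottleneck = conv_params(256, 64, 1) + conv_params(64, 64, 3) + conv_params(64, 256, 1)

# Comparison point: two 3x3 convolutions kept at 256 channels.
plain = 2 * conv_params(256, 256, 3)

print(bottleneck)                    # 69632
print(plain)                         # 1179648
print(round(plain / bottleneck, 1))  # 16.9
```

The expensive 3x3 convolution runs on 64 channels instead of 256, which is where most of the saving comes from.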
ResNet-50 architecture: The ResNet-50 model consists of an initial convolutional layer, followed by four layer groups of bottleneck blocks, and concludes with global average pooling and a fully connected layer for classification.
# ---------------------------
# Step 2: Define ResNet Framework
# ---------------------------
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # CIFAR-10: no downsample
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        # Project the shortcut when the spatial size or channel count changes
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)  # conv2_x
        x = self.layer2(x)  # conv3_x
        x = self.layer3(x)  # conv4_x
        x = self.layer4(x)  # conv5_x
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
Then create the model through a factory function:
# ---------------------------
# Step 3: Factory Function
# ---------------------------
def get_resnet50_for_cifar10():
    # torchvision alternative, kept for reference:
    # model = models.resnet50(pretrained=False)
    # model.fc = nn.Linear(model.fc.in_features, 10)
    # return model
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)
# Build the model
model = get_resnet50_for_cifar10()
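Here [3, 4, 6, 3] gives the number of bottleneck blocks in each layer group. The feature map sizes are shown for a 224 x 224 ImageNet input with the standard stem and, for comparison, for a 32 x 32 CIFAR-10 input with the stride-1 stem defined above:
layer group | feature map size (ImageNet, 224 x 224 input) | feature map size (CIFAR-10, 32 x 32 input) | bottleneck blocks |
---|---|---|---|
layer1 | 56 x 56 | 32 x 32 | 3 blocks |
layer2 | 28 x 28 | 16 x 16 | 4 blocks |
layer3 | 14 x 14 | 8 x 8 | 6 blocks |
layer4 | 7 x 7 | 4 x 4 | 3 blocks |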
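The feature map sizes follow from the standard convolution output formula. The sketch below is a small worked calculation, with hypothetical helper names, for both the 224 x 224 ImageNet setting and the 32 x 32 CIFAR-10 stem used in the implementation above (3x3 convolution, stride 1, no max pooling). It also shows where the "50" in ResNet-50 comes from.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(size_after_stem):
    """Feature map side length after layer1..layer4 (strides 1, 2, 2, 2)."""
    sizes = []
    size = size_after_stem
    for stride in (1, 2, 2, 2):
        # The stride is applied by the 3x3 convolution of the first block.
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes.append(size)
    return sizes


# CIFAR-10 stem from this post: 3x3 conv, stride 1, padding 1, no max pooling.
cifar_stem = conv_out(32, kernel=3, stride=1, padding=1)
print(stage_sizes(cifar_stem))  # [32, 16, 8, 4]

# Standard ImageNet stem: 7x7 conv (stride 2, padding 3), then 3x3 max pool (stride 2, padding 1).
imagenet_stem = conv_out(conv_out(224, 7, 2, 3), 3, 2, 1)
print(stage_sizes(imagenet_stem))  # [56, 28, 14, 7]

# Why "50": 3 convs per bottleneck block, plus the stem conv and the final fc layer.
blocks = [3, 4, 6, 3]
print(3 * sum(blocks) + 2)  # 50
```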
Time to train 🏋️‍♂️
- Plotly is used to visualize the training process. The live_plot function is called at the end of each epoch to update the plot with the latest training and test losses and accuracies, so we can monitor the training process in real time!
How to train the model
Let’s first determine what we need to train/test:
- We need to take in a model to train
- We need to take in data, which is handled by a DataLoader: the loader object
- We need to take in an optimizer to update the model (required for training only; testing does not need it)
- We need to take in the current epoch number, used for logging (the total number of epochs is decided by the outer loop)
- Finally, we need to take in the logger object we defined previously to log the training process
# ============================
# Section 4: Model Training & Evaluation
# ============================
def train(model, loader, optimizer, epoch, logger):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)
    avg_loss = total_loss / len(loader)
    accuracy = correct / total
    print(f"Train Epoch: {epoch}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    logger.log_train(avg_loss, accuracy)

def test(model, loader, logger):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    avg_loss = total_loss / len(loader)
    acc = correct / total
    print(f"Test Accuracy: {acc:.4f}, Loss: {avg_loss:.4f}")
    logger.log_test(avg_loss, acc)
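One detail worth knowing about the averaging above: `total_loss / len(loader)` is the mean of the per-batch mean losses, which differs slightly from the per-sample average when the last batch is smaller than the rest. The sketch below uses made-up loss values and batch sizes purely to show the difference.

```python
# Hypothetical per-batch mean losses and batch sizes; the last batch is smaller,
# as happens when the dataset size is not a multiple of the batch size.
batch_losses = [1.0, 1.0, 3.0]
batch_sizes = [64, 64, 16]

# What train() and test() above compute: the mean of per-batch means.
mean_of_means = sum(batch_losses) / len(batch_losses)

# Per-sample average: weight each batch mean by its number of samples.
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(round(mean_of_means, 4))  # 1.6667
print(round(weighted, 4))       # 1.2222
```

With CIFAR-10 and a batch size of 64, only the final batch of 16 images is affected (1 of 782 training batches), so the effect on the logged loss is small in practice. The accuracy is already computed per sample.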
Gather the training process in one place
We can gather all of this into one place; let’s treat the entire process as an experiment:
# ============================
# Section 5: Run Experiments
# ============================
import copy  # for deep copying model state_dict

def run_experiment(name, model, train_dataset, test_dataset, epochs,
                   optimizer_class=optim.SGD, optimizer_kwargs=None,
                   optimizer=None, lr_scheduler=None,
                   batch_size=64, val_batch_size=64,
                   patience=None, delta=1e-4,
                   checkpoint_path=None  # path to save the best model
                   ):
    best_model_state = None
    model = model.to(device)
    show_model_structure(model, name)
    logger = ExperimentLogger()
    # Use provided optimizer or construct one
    if optimizer is None:
        optimizer_kwargs = optimizer_kwargs or {"lr": 0.1, "momentum": 0.9}
        optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=val_batch_size, shuffle=False)
    best_accuracy = 0
    epochs_no_improve = 0
    for epoch in range(1, epochs + 1):
        train(model, train_loader, optimizer, epoch, logger)
        test(model, test_loader, logger)
        current_lr = optimizer.param_groups[0]['lr']
        logger.log_learning_rate(current_lr)
        if lr_scheduler:
            lr_scheduler.step()
        if logger.test_accuracies[-1] > best_accuracy + delta:
            best_accuracy = logger.test_accuracies[-1]
            best_model_state = copy.deepcopy(model.state_dict())  # 💾 store best model
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if patience and epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs)")
            break
        logger.live_plot(epoch, name)
    if checkpoint_path and best_model_state:
        torch.save(best_model_state, checkpoint_path)
        print(f"✅ Saved best model to: {checkpoint_path}")
    return model, logger
According to the ResNet paper by Kaiming He et al., the authors trained their ImageNet models, including ResNet-50, using the following settings:
- Optimizer: Stochastic Gradient Descent (SGD) with momentum
- Momentum: 0.9
- Weight Decay: 1e-4
- Batch Size: 256
- Initial Learning Rate: 0.1
- Learning Rate Schedule: The learning rate was divided by 10 when the validation error plateaued.
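A minimal sketch of how such a setup could be expressed in PyTorch, using a tiny stand-in model so that it runs on its own. The learning rate of 0.1 matches the default in `run_experiment`; `patience=2` and the accuracy values are illustrative assumptions, not values from the paper. Note that `run_experiment` calls `lr_scheduler.step()` without arguments, so a plateau-based scheduler like this one would need the metric passed in; the loop here drives the scheduler directly.

```python
import torch.nn as nn
import torch.optim as optim

# Tiny stand-in model; any nn.Module would do for demonstrating the optimizer setup.
model = nn.Linear(8, 2)

# SGD with momentum 0.9 and weight decay 1e-4; lr=0.1 as in run_experiment's default.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# Divide the learning rate by 10 when the monitored accuracy stops improving.
# patience=2 is an illustrative choice.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=2
)

# Hypothetical validation accuracies that plateau immediately.
lrs = []
for val_acc in [0.50, 0.50, 0.50, 0.50]:
    scheduler.step(val_acc)  # this scheduler needs the metric passed in
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)  # the last entry is reduced by a factor of 10
```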
Here’s the interactive Plotly visualization of the training results (without data augmentation):
And here is the visualization of the training results (with data augmentation):