Training a Convolutional Neural Network (CNN) with PyTorch for Image Classification
This snippet shows how to build and train a Convolutional Neural Network (CNN) for image classification using PyTorch on the MNIST dataset. It includes defining the CNN architecture, loading and preprocessing the MNIST dataset, setting up the loss function and optimizer, training the model, and evaluating its performance.
Importing Necessary Libraries
This imports the required libraries. torch is the main PyTorch library. torch.nn provides classes for building neural networks. torch.optim contains optimization algorithms. torchvision provides access to popular datasets (like MNIST) and image transformation tools. torchvision.transforms is used for data preprocessing. torch.utils.data.DataLoader enables efficient data loading and batching.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
Defining the CNN Architecture
This defines the CNN architecture. It consists of two convolutional layers (nn.Conv2d) followed by ReLU activation functions (nn.ReLU) and max pooling layers (nn.MaxPool2d). The output of the convolutional layers is flattened and passed through a fully connected layer (nn.Linear) to produce the final classification output (10 classes for MNIST digits). The forward method describes the data flow through the network.
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # First convolutional block: 1 input channel (grayscale) -> 16 feature maps.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Second convolutional block: 16 -> 32 feature maps.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier: 32 feature maps of size 7x7 -> 10 digit classes.
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)  # Flatten each sample to a 1D vector
        x = self.fc(x)
        return x
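The flatten size of 32 * 7 * 7 follows from the two 2x2 max-pooling layers halving the 28x28 MNIST input twice (28 -> 14 -> 7). If in doubt, a quick sanity check with a dummy batch (illustrative only, not part of the original snippet) confirms the output shape:
dummy = torch.randn(1, 1, 28, 28)   # one fake grayscale 28x28 image
print(CNN()(dummy).shape)           # expected: torch.Size([1, 10])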
Loading and Preprocessing the MNIST Dataset
This loads the MNIST dataset using torchvision.datasets.MNIST. The transforms.Compose defines a sequence of image transformations: transforms.ToTensor() converts images to PyTorch tensors, and transforms.Normalize normalizes the pixel values to the range [-1, 1], which improves training performance. DataLoader objects are created to handle batching and shuffling of the data.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize pixel values to [-1, 1]
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
Setting up Loss Function and Optimizer
This instantiates the CNN model. nn.CrossEntropyLoss is the loss function used for multi-class classification. optim.Adam is the optimization algorithm used to update the model's parameters during training. The learning rate is set to 0.001.
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
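The snippet trains on the CPU. A common variation, not shown in the original, is to move the model and each batch to a GPU when one is available; a minimal sketch:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Inside the training and evaluation loops, move each batch as well:
#     images, labels = images.to(device), labels.to(device)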
Training the Model
This is the training loop. It iterates over the specified number of epochs. In each epoch, it iterates over the batches of data provided by the train_loader. It performs a forward pass to get the model's output, calculates the loss, performs a backward pass to compute the gradients, and updates the model's parameters using the optimizer. optimizer.zero_grad() clears the gradients from the previous iteration. The loss is printed every 100 steps.
num_epochs = 3
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
Evaluating the Model
This evaluates the trained model on the test dataset. torch.no_grad() disables gradient calculation during evaluation. It iterates over the batches of data in the test_loader, performs a forward pass, and compares the predicted labels to the actual labels to calculate the accuracy.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        # torch.max returns (values, indices); the indices are the predicted classes.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'Accuracy of the model on the test images: {100 * correct / total:.2f} %')
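Two small additions that are common at this point, sketched here as suggestions rather than part of the original snippet: switching the model to evaluation mode (which matters when dropout or batch-norm layers are present; this model has neither, so it changes nothing here) and saving the trained weights for later reuse (the file name is an arbitrary choice).
model.eval()                                    # inference mode; affects dropout/batch-norm if present
torch.save(model.state_dict(), 'mnist_cnn.pt')  # persist the trained weights

restored = CNN()                                # reload into a fresh instance later
restored.load_state_dict(torch.load('mnist_cnn.pt'))
restored.eval()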
Concepts Behind the Snippet
This snippet demonstrates the core concepts of CNNs: convolutional layers for feature extraction, pooling layers for reducing dimensionality and increasing robustness, and fully connected layers for classification. It also shows how to load and preprocess image data using torchvision
.
Real-Life Use Case
Image classification is used in many real-world applications, such as handwritten digit and document recognition, medical image analysis, facial recognition, and defect detection in manufacturing.
Best Practices
Interview Tip
Be prepared to explain the roles of convolutional layers, pooling layers, and activation functions in CNNs. Also, be ready to discuss how to choose appropriate hyperparameters (e.g., learning rate, batch size, kernel size) and how to prevent overfitting.
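As an illustration of the overfitting countermeasures mentioned above, dropout and weight decay require only small changes; this is a hedged sketch of common additions, not part of the original model:
# Dropout: add to CNN.__init__ ...
#     self.dropout = nn.Dropout(p=0.5)
# ... and apply it in forward() just before self.fc:
#     x = self.dropout(x)

# Weight decay (L2 regularization) is a single optimizer argument:
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)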
When to use them
Use CNNs when you are dealing with image data or other data with spatial relationships. CNNs are particularly effective at automatically learning hierarchical features from raw pixel data.
Memory Footprint
The memory footprint of a CNN depends on the number of layers, the number of filters in each layer, the size of the filters, and the batch size. Deeper and wider CNNs require more memory. Techniques like model compression (e.g., pruning, quantization) can help reduce the memory footprint.
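A concrete way to gauge model size is to count parameters; the check below (an illustrative sketch) prints the count for the CNN above together with a rough in-memory size, assuming 32-bit floats.
num_params = sum(p.numel() for p in model.parameters())
print(f'Parameters: {num_params}')                                # about 20k for this small CNN
print(f'Approx. weight size: {num_params * 4 / 1024**2:.2f} MB')  # 4 bytes per float32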
Alternatives
Alternatives to building a CNN from scratch include transfer learning with pre-trained architectures (e.g., those in torchvision.models) and higher-level training frameworks such as PyTorch Lightning or fastai, which reduce boilerplate while still using PyTorch underneath.
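As a hedged illustration of the transfer-learning route: a pre-trained ResNet-18 from torchvision.models can be adapted to a 10-class problem by swapping its final layer (the weights argument below requires torchvision >= 0.13; note that ResNet expects 3-channel input, so grayscale images would need converting or the first convolution replacing).
import torchvision.models as models

pretrained = models.resnet18(weights='DEFAULT')            # ImageNet-pretrained backbone
pretrained.fc = nn.Linear(pretrained.fc.in_features, 10)   # replace the 1000-class head

# Optionally freeze the backbone and train only the new head.
for name, param in pretrained.named_parameters():
    param.requires_grad = name.startswith('fc.')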
Pros
Cons
FAQ
- What is the purpose of the `transforms.Normalize` transformation?
transforms.Normalize normalizes the pixel values of the images to a specific range (usually [-1, 1] or [0, 1]) with a given mean and standard deviation. This helps to improve training stability and convergence by ensuring that the input data has a consistent distribution.
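Concretely, transforms.ToTensor() scales pixels into [0, 1], and Normalize((0.5,), (0.5,)) then applies (x - mean) / std, so 0 maps to -1 and 1 maps to +1. A quick check (illustrative only):
x = torch.tensor([0.0, 0.5, 1.0])   # pixel values after ToTensor()
print((x - 0.5) / 0.5)              # tensor([-1., 0., 1.])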
- How do I adjust the CNN architecture for different image sizes or numbers of classes?
To adjust the CNN architecture for different image sizes, you may need to modify the kernel sizes, strides, and padding of the convolutional layers, as well as the size of the fully connected layers. For a different number of classes, you need to change the output size of the final fully connected layer to match the number of classes. A hypothetical adaptation is sketched below.
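For example, a hypothetical variant for 3-channel 32x32 images (such as CIFAR-10) and 100 classes only needs a different first convolution and classifier head; after two 2x2 poolings the spatial size goes 32 -> 16 -> 8.
class SmallCNN(nn.Module):
    def __init__(self, in_channels=3, num_classes=100):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(32 * 8 * 8, num_classes)  # 32x32 input -> 8x8 after two poolings

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)   # flatten
        return self.fc(x)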