The VGG (Visual Geometry Group) model is a convolutional neural network (CNN) introduced in the paper Very Deep Convolutional Networks for Large-Scale Image Recognition. It’s known for its use of small convolution filters and very deep architectures, which helped it achieve state-of-the-art performance on tasks like image classification. By stacking multiple layers with small kernel sizes, VGG can capture a wide range of features from input images, and the additional rectification (ReLU) layers that come with each stacked convolution make the decision function more discriminative. Some configurations in the paper also use 1x1 convolutional layers to add nonlinearity without affecting the receptive field. For training, VGG follows the standard supervised learning approach, where input images and ground truth labels are provided.
VGG’s architecture has significantly shaped the field of neural networks, serving as a foundation and benchmark for many subsequent models in computer vision.
In this blog post, we’ll guide you through implementing and training the VGG architecture using PyTorch, step by step. You can find the complete code for defining and training the VGG model on my GitHub repository (URL: https://github.com/JianZhongDev/VGGPyTorch ).
VGG Architecture and Implementation
As you can see in the cover image of this post, the VGG model is made up of multiple stacks of convolutional layers, each followed by max pooling, and it ends with a few fully connected layers. The output from these layers is then fed into a softmax layer to give a normalized confidence score for each image category.
The key building blocks of the VGG network are these stacked convolutional layers and stacked fully connected layers, so we will start our implementation with them.
Stacked Convolutional Layers
To start, we’ll create the stacked convolutional layer as a PyTorch nn.Module, like this:
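Below is a simplified sketch of such a module (the full implementation is in the linked repository). The descriptor keys (in_channels, out_channels, kernel_size, stride, padding, activation) and their default values are illustrative choices here and may differ from the repository code.

```python
import torch.nn as nn

class VGGStacked2DConv(nn.Module):
    """Stack of (Conv2d + activation) blocks built from a list of descriptor dicts."""
    def __init__(self, descriptors):
        super().__init__()
        layers = []
        for cur_descriptor in descriptors:
            # required configuration parameters
            in_channels = cur_descriptor["in_channels"]
            out_channels = cur_descriptor["out_channels"]
            # unspecified parameters fall back to default values
            kernel_size = cur_descriptor.get("kernel_size", 3)
            stride = cur_descriptor.get("stride", 1)
            padding = cur_descriptor.get("padding", 1)
            activation = cur_descriptor.get("activation", nn.ReLU)
            layers.append(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding)
            )
            layers.append(activation())
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
```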
The stacked convolutional layer takes in a list of descriptor dictionaries, each detailing the setup for a repeated convolutional layer followed by an activation. It reads these configurations and builds the stacked convolutional layers accordingly. If certain configuration parameters are not specified, the code fills in default values.
Stacked Fully-Connected and Dropout Layers
VGG uses dropout regularization in its fully connected layers. Adding dropout in PyTorch is straightforward: we just need to insert a dropout layer after each hidden layer inside the stacked fully connected layer. (NOTE: Section 4.2 of the AlexNet paper provides valuable insights into dropout layers. It’s definitely worth a read.)
We can define the stacked fully connected layer in a similar manner as the stacked convolutional layers:
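Here is a corresponding sketch (again, the full version lives in the repository). Each descriptor configures one linear layer; hidden layers are followed by an activation and a dropout layer, and the dropout probability defaults to 0.5 as in the paper. The descriptor keys are illustrative.

```python
import torch.nn as nn

class VGGStackedLinear(nn.Module):
    """Stack of fully connected layers; each hidden layer is followed by an
    activation and a dropout layer, while the final layer is left linear."""
    def __init__(self, descriptors):
        super().__init__()
        layers = []
        nof_layers = len(descriptors)
        for idx, cur_descriptor in enumerate(descriptors):
            in_features = cur_descriptor["in_features"]
            out_features = cur_descriptor["out_features"]
            # unspecified parameters fall back to default values
            activation = cur_descriptor.get("activation", nn.ReLU)
            dropout_p = cur_descriptor.get("dropout_p", 0.5)
            layers.append(nn.Linear(in_features, out_features))
            if idx < nof_layers - 1:  # hidden layers only
                layers.append(activation())
                layers.append(nn.Dropout(p=dropout_p))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
```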
VGG Model
Now that we’ve defined the stacked convolutional and fully-connected layers, we can construct the VGG model as follows:
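A sketch of the full model, built from the two stacked-layer modules above, could look like this (the constructor argument names are illustrative):

```python
import torch.nn as nn

class VGG(nn.Module):
    """VGG: stacked conv blocks, each followed by max pooling,
    then flattening, stacked fully connected layers, and a Softmax."""
    def __init__(self, conv_stack_descriptors, linear_descriptors):
        super().__init__()
        layers = []
        # one stacked convolutional layer per descriptor list, each followed by max pooling
        for cur_stack_descriptors in conv_stack_descriptors:
            layers.append(VGGStacked2DConv(cur_stack_descriptors))
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # flatten the convolutional feature maps
        layers.append(nn.Flatten(start_dim=1, end_dim=-1))
        # stacked fully connected layers
        layers.append(VGGStackedLinear(linear_descriptors))
        # normalized confidence score for each category
        layers.append(nn.Softmax(dim=-1))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
```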
The VGG model takes in a list of stacked-convolutional-layer descriptor lists and a list of fully connected layer descriptors. First, it iterates through the convolutional descriptors, creating a stacked convolutional layer for each one and adding a max pooling layer after each stack. Then, it flattens the output of the convolutional part and constructs the stacked fully connected layers based on the linear layer descriptors. Finally, a Softmax layer is appended at the end of the network.
Model Generation
Using the model definition provided above, we can create a VGG model by specifying a few layer descriptors. For instance, we can replicate the VGG16 model described in the VGG paper as follows:
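The sketch below describes the 16-layer configuration with 1x1 convolutions using the descriptor keys assumed above; the helper functions conv3 and conv1 are just illustrative shorthand for 3x3 and 1x1 convolution descriptors.

```python
# shorthand helpers for 3x3 and 1x1 convolution descriptors (illustrative)
def conv3(in_ch, out_ch):
    return {"in_channels": in_ch, "out_channels": out_ch, "kernel_size": 3, "padding": 1}

def conv1(in_ch, out_ch):
    return {"in_channels": in_ch, "out_channels": out_ch, "kernel_size": 1, "padding": 0}

vgg16_conv_stack_descriptors = [
    [conv3(3, 64), conv3(64, 64)],
    [conv3(64, 128), conv3(128, 128)],
    [conv3(128, 256), conv3(256, 256), conv1(256, 256)],
    [conv3(256, 512), conv3(512, 512), conv1(512, 512)],
    [conv3(512, 512), conv3(512, 512), conv1(512, 512)],
]

vgg16_linear_descriptors = [
    {"in_features": 512 * 7 * 7, "out_features": 4096},
    {"in_features": 4096, "out_features": 4096},
    {"in_features": 4096, "out_features": 1000},
]

vgg16_model = VGG(vgg16_conv_stack_descriptors, vgg16_linear_descriptors)
print(vgg16_model)
```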
Here’s what the printout of the VGG16 model looks like:
VGG(
  (network): Sequential(
    (0): VGGStacked2DConv(
      (network): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): VGGStacked2DConv(
      (network): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): VGGStacked2DConv(
      (network): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (5): ReLU()
      )
    )
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): VGGStacked2DConv(
      (network): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (5): ReLU()
      )
    )
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): VGGStacked2DConv(
      (network): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (5): ReLU()
      )
    )
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): VGGStackedLinear(
      (network): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU()
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    (12): Softmax(dim=-1)
  )
)
Data Processing
In the VGG paper, the only preprocessing applied to the input data is subtracting the mean RGB value, computed on the training set, from each pixel. To apply this preprocessing, we start by going through the entire training dataset and computing the mean value for each color channel.
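A minimal sketch of this computation, using CIFAR10 (the dataset we train on later) and an illustrative data path and batch size:

```python
import torch
import torchvision
import torchvision.transforms as transforms

# load the training set with only a ToTensor transform
train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True,
    transform=transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=False)

# accumulate the per-channel sum over every pixel of every training image
channel_sum = torch.zeros(3)
nof_pixels = 0
for images, _ in train_loader:
    # images shape: (batch, channel, height, width)
    channel_sum += images.sum(dim=(0, 2, 3))
    nof_pixels += images.shape[0] * images.shape[2] * images.shape[3]

mean_channel_values = channel_sum / nof_pixels  # one mean value per RGB channel
```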
Using these mean channel values, we can perform the preprocessing by defining a mean-subtraction function and wrapping it in the Lambda() transform provided by torchvision, like this:
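A sketch of the transform, assuming the mean_channel_values tensor computed above; the function name is illustrative:

```python
import torchvision.transforms as transforms

def subtract_mean(image_tensor):
    # broadcast the per-channel mean over the (channel, height, width) image tensor
    return image_tensor - mean_channel_values.reshape(3, 1, 1)

mean_subtraction_transform = transforms.Lambda(subtract_mean)
```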
Data Augmentation
The VGG paper also employed various data augmentation techniques to prevent overfitting. Here’s how we implement them:
Random Horizontal Flip
torchvision already includes a built-in transformation for randomly flipping images horizontally, so we can simply use it for this augmentation.
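For example:

```python
import torchvision.transforms as transforms

# flip each training image horizontally with a probability of 0.5
random_hflip_transform = transforms.RandomHorizontalFlip(p=0.5)
```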
In the VGG paper, the authors ran both the original image and its horizontally flipped counterpart through the network and averaged the resulting class confidences to obtain the final classification. Consequently, we can implement the validation process as follows:
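A sketch of such a validation pass (the model, data loader, and device names are placeholders); it reports top-1 accuracy, and the top-k version described later follows the same pattern:

```python
import torch
import torchvision.transforms.functional as TF

def validate_with_hflip(model, val_loader, device="cpu"):
    """Average the class confidences of each image and its horizontal flip."""
    model.eval()
    nof_correct = 0
    nof_samples = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            probs = model(images)                    # original images
            probs_flipped = model(TF.hflip(images))  # horizontally flipped images
            avg_probs = 0.5 * (probs + probs_flipped)
            predictions = avg_probs.argmax(dim=-1)
            nof_correct += (predictions == labels).sum().item()
            nof_samples += labels.shape[0]
    return nof_correct / nof_samples
```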
Random Color Shift
In VGG, another augmentation technique involved adjusting the RGB values of training images by a random combination of the principal component analysis (PCA) eigenvectors derived from the RGB values across all pixels of all images in the training set. For a detailed explanation, refer to section 4.1 Data Augmentation in the AlexNet paper.
Here’s how the random color shift is implemented:
Before training begins, we go through the entire training set and gather the RGB values of every pixel. This data is used to create an \(m \times n\) data matrix, where \(n\) is the number of channels (3 for RGB images) and \(m\) is the total number of pixels across all images in the training set (\(m = \text{number of images} \times \text{image height} \times \text{image width}\)). We then calculate the covariance matrix of this data matrix and perform principal component analysis (PCA) on it using singular value decomposition (SVD). The resulting \(U\) matrix contains columns representing the PCA eigenvectors, and the \(S\) matrix contains the corresponding eigenvalues.
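A sketch of this computation, reusing the plain ToTensor train_loader from the preprocessing step above:

```python
import torch

# gather all RGB values into an (m x n) data matrix (n = 3 channels)
all_rgb = []
for images, _ in train_loader:
    # (batch, 3, H, W) -> (batch * H * W, 3)
    all_rgb.append(images.permute(0, 2, 3, 1).reshape(-1, 3))
data_matrix = torch.cat(all_rgb, dim=0)

# covariance matrix of the centered data, then PCA via SVD
centered = data_matrix - data_matrix.mean(dim=0, keepdim=True)
cov_matrix = centered.T @ centered / (centered.shape[0] - 1)  # 3 x 3
U, S, Vh = torch.linalg.svd(cov_matrix)
# columns of U are the PCA eigenvectors, S holds the corresponding eigenvalues
```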
Note: In this implementation, all pixels are loaded into computer memory at the same time. For larger datasets, the code for calculating the covariance matrix may need enhancements to compute it without simultaneously loading all data into memory.
During training, we create a randomized linear combination of PCA eigenvectors by adding up the product of each eigenvector with a randomized amplitude. This amplitude is computed by multiplying the corresponding eigenvalue by a random value drawn from a Gaussian distribution with a mean of 0 and a standard deviation of 0.1.
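A sketch of the transform, assuming the U and S matrices computed above; the function name and the way it is wrapped in Lambda() are illustrative:

```python
import torch
import torchvision.transforms as transforms

def random_color_shift(image_tensor, eigvecs=U, eigvals=S, std=0.1):
    # random amplitude for each eigenvector: eigenvalue * Gaussian(0, std) sample
    alphas = torch.randn(3) * std
    # linear combination of the eigenvectors (columns of eigvecs)
    shift = (eigvecs * (alphas * eigvals)).sum(dim=1)  # one offset per RGB channel
    return image_tensor + shift.reshape(3, 1, 1)

random_color_shift_transform = transforms.Lambda(random_color_shift)
```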
Other Data Augmentations
The VGG paper also employed additional augmentation techniques like random translations and random crops. However, since CIFAR10 images are much smaller (32x32) than the ImageNet training inputs (224x224 crops taken from rescaled images), there isn’t much room to apply these techniques effectively.
Summary of Data Transformations
In summary, the data transformations for the training set, including preprocessing and all data augmentation techniques, can be implemented as follows:
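Putting the pieces together, with the function names and transform order assumed above:

```python
import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.ToTensor(),                   # to a float tensor in [0, 1]
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.Lambda(random_color_shift),   # random PCA color shift
    transforms.Lambda(subtract_mean),        # mean RGB subtraction (preprocessing)
])
```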
For the test/validation set, all we need to do is include the preprocessing step in the data transformations.
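For example:

```python
import torchvision.transforms as transforms

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(subtract_mean),  # only the mean RGB subtraction
])
```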
Training and Validation
Top k Accuracy (or Error)
In the VGG paper, performance is mainly reported as the top k error. In my implementation, I calculate the top k accuracy instead. Top k accuracy measures how often the ground truth label is among the k predictions with the highest model confidence, whereas top k error measures how often the ground truth label is not among those top k predictions.
Top k error and top k accuracy are related by the following formula:
$$ \text{top k error} = 1 − \text{top k accuracy} $$
A higher top k accuracy and lower top k error indicate better model performance.
During the validation (or test) process, the top k accuracy can be estimated by dividing the number of samples whose label is among the top k predictions by the total number of validation (or test) samples.
$$ \text{top k accuracy} = \frac{\text{number of samples (label is in top k predictions)}}{\text{total number of samples}} $$
Therefore, top k accuracy can be calculated using the following code:
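A sketch of the per-batch computation (the function and variable names are illustrative):

```python
import torch

def count_topk_hits(predict_probs, labels, k=5):
    """Count samples whose ground truth label is among the top k predictions."""
    # sort predicted confidences in descending order and keep the top k class indices
    _, sorted_indices = torch.sort(predict_probs, dim=-1, descending=True)
    topk_indices = sorted_indices[:, :k]
    # boolean mask: True where the label appears among the top k predictions
    hit_mask = (topk_indices == labels.reshape(-1, 1)).any(dim=-1)
    return hit_mask.sum().item()

# accumulate over all validation batches:
# total_hits += count_topk_hits(model(images), labels, k)
# total_samples += labels.shape[0]
# topk_accuracy = total_hits / total_samples
```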
In each batch, we sort the softmax outputs, which represent the confidence for each predicted category, in descending order. Then, we check whether the ground truth label is among the top k predictions. The result of this check is stored in a boolean mask array, where true indicates the label is in the top k predictions and false indicates it is not. This mask holds the results for all samples within the batch, so the total number of samples whose label is among the top k predictions is obtained by summing the masks from all batches.
Loss Function, Regularization, and Optimizer
VGG uses the multinomial logistic regression objective (i.e., cross-entropy) as its loss function. For optimization, it uses mini-batch gradient descent with momentum and weight decay. In PyTorch, these can be implemented as follows:
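A sketch using the hyperparameters from the paper (learning rate 1E-2, momentum 0.9, weight decay 5E-4). Because the model defined above already ends with a Softmax layer, the multinomial logistic regression objective is implemented here by taking the log of the model output and applying NLLLoss; model stands for the VGG instance created earlier (e.g., vgg16_model).

```python
import torch

criterion = torch.nn.NLLLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1E-2,
    momentum=0.9,
    weight_decay=5E-4,
)

# inside the training loop:
# predict_probs = model(images)
# loss = criterion(torch.log(predict_probs + 1E-12), labels)
# optimizer.zero_grad(); loss.backward(); optimizer.step()
```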
Additionally, as mentioned earlier in this post, the dropout layers in the fully connected stack provide another form of regularization.
Learning Rate Adjustment
In the VGG paper, the authors start training with a learning rate of 1E-2 and then reduce it by a factor of 10 whenever the validation accuracy plateaus. This can be implemented using the ReduceLROnPlateau() scheduler provided by PyTorch, like this:
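For example (the patience value is an illustrative choice):

```python
import torch

# reduce the learning rate by a factor of 10 when validation accuracy stops improving
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",   # we monitor accuracy, so higher is better
    factor=0.1,   # multiply the learning rate by 0.1 on plateau
    patience=5,   # epochs with no improvement before reducing
)

# at the end of each epoch:
# val_accuracy = validate_with_hflip(model, val_loader)
# scheduler.step(val_accuracy)
```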
NOTE: The description of ReduceLROnPlateau() in the PyTorch documentation can be confusing. I found that reading the source code of its definition provides a clearer understanding.
Training Deep Models
Optimizing deep models from scratch with completely random initialization can be very challenging for the optimizer. It often leads to the learning process getting stuck for long periods.
To tackle this issue, the VGG authors first train a shallow model. Then, they use the learned parameters from this shallow model to initialize deeper ones.
Transferring learned parameters between models in PyTorch is straightforward: we copy the learnable parameters (i.e., the weights and biases stored in each layer’s state_dict) between corresponding layers of the two models. If you’re using the VGG model definition from this blog post, the example code looks like this:
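A sketch of such a transfer; shallow_model and deep_model are placeholder names for two VGG instances whose first stacked convolutional layer and stacked fully connected layer have matching shapes:

```python
import copy

def transfer_layer_params(src_layer, dst_layer):
    # copy the learnable parameters (weights and biases) of one layer into another
    dst_layer.load_state_dict(copy.deepcopy(src_layer.state_dict()))

# initialize the first conv stack and the fully connected stack of the deeper
# model with the parameters learned by the shallow model
transfer_layer_params(shallow_model.network[0], deep_model.network[0])
transfer_layer_params(shallow_model.network[-2], deep_model.network[-2])
```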
NOTE: The sequential layers of the VGG model are stacked inside the “network” attribute of the module, so we can access each specific layer by indexing into “network”.
Results
Given the large size of the ImageNet dataset and the extensive time required for training, we’ll opt for a smaller dataset, CIFAR10, to demonstrate training and validation more quickly.
I’ve examined several models based on the VGG architecture, and I’ve listed some of them (model I, II, and III) below:
Model configuration:

| I | II | III |
|---|---|---|
| conv3-128 | conv3-128 | conv3-128 |
| maxpool | maxpool | maxpool |
| conv3-256 | conv3-256 | conv3-256 |
|  | conv3-256 | conv3-256 |
| maxpool | maxpool | maxpool |
| conv3-512 | conv3-512 | conv3-512 |
|  |  | conv3-512 |
|  |  | conv3-512 |
|  | conv3-512 | conv3-512 |
|  |  | conv3-512 |
|  |  | conv3-512 |
| maxpool | maxpool | maxpool |
| FC-1024 | FC-1024 | FC-1024 |
|  | FC-1024 | FC-1024 |
| FC-10 | FC-10 | FC-10 |
| soft-max | soft-max | soft-max |
After training these model variations, I computed the top 1 to top 5 accuracies using the CIFAR10 test dataset. Here’s a summary of the results:
| Model config. | top-1 accuracy (%) | top-2 accuracy (%) | top-3 accuracy (%) | top-4 accuracy (%) | top-5 accuracy (%) |
|---|---|---|---|---|---|
| I | 82.45 | 92.74 | 96.23 | 97.82 | 98.87 |
| II | 84.88 | 93.95 | 96.91 | 98.23 | 98.99 |
| III | 86.93 | 94.39 | 96.83 | 98.15 | 98.90 |
We can observe that the accuracy tends to improve as the depth of the models increases.
Conclusion
In this blog post, we’ve covered the implementation, training, and evaluation of the VGG network in a step-by-step manner. The VGG model showcases the effectiveness of deep neural networks in tackling image classification tasks. Moreover, its data augmentation, regularization, and training methods provide valuable insights and lessons for training deep neural networks.
Reference
[1] Simonyan, K., & Zisserman, A. (2014). Very Deep Convolutional Networks for Large-Scale Image Recognition. arXiv. https://doi.org/10.48550/arxiv.1409.1556
[2] Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). ImageNet Classification with Deep Convolutional Neural Networks. Neural Information Processing Systems, 25, 1097–1105. https://papers.nips.cc/paper_files/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html
[3] Krizhevsky, A., Nair, V. and Hinton, G. (2014) The CIFAR-10 Dataset. https://www.cs.toronto.edu/~kriz/cifar.html
Citation
If you found this article helpful, please cite it as:
Zhong, Jian (May 2024). Building and Training VGG with PyTorch: A Step-by-Step Guide. Vision Tech Insights. https://jianzhongdev.github.io/VisionTechInsights/posts/implement_train_vgg_pytorch/.
Or
@article{zhong2024buildtrainVGGPyTorch,
title = "Building and Training VGG with PyTorch: A Step-by-Step Guide",
author = "Zhong, Jian",
journal = "jianzhongdev.github.io",
year = "2024",
month = "May",
url = "https://jianzhongdev.github.io/VisionTechInsights/posts/implement_train_vgg_pytorch/"
}