Normalizing Images on PyTorch for Computer Vision Training
Using PyTorch’s torchvision to load image datasets and normalize them by calculating mean and standard deviation.
Background
Normalizing images is a preprocessing step that adjusts the pixel values of an image to a specific range or distribution. Common ranges include [0, 1] or [-1, 1].
For example, if the original pixel values of an image are in the range [0, 255], dividing them by 255 would scale them to the range [0, 1]. This rescaling is often referred to as “min-max scaling.”
The main purpose of normalizing images is to ensure that the data has a consistent scale and distribution, which can lead to more stable and efficient model training, better convergence, and improved model performance.
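As a small illustration of the rescaling described above, here is a minimal sketch that maps 8-bit pixel values to [0, 1] and then to [-1, 1]. The pixel values are made up for the example.

```python
import torch

# Hypothetical 2x2 8-bit image with values in [0, 255]
pixels = torch.tensor([[0, 51], [204, 255]], dtype=torch.uint8)

# Scale to [0, 1] by dividing by the maximum possible value
unit_range = pixels.float() / 255.0

# Shift and stretch [0, 1] to [-1, 1]
signed_range = unit_range * 2.0 - 1.0

print(unit_range)
print(signed_range)
```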
Questions
So then here is something that I started to wonder when it comes to DICOM images. DICOM images are already in grayscale.
Do we need to normalize a grayscale image?
I would assume that grayscale images are already somewhat normalized because, well, they are grayscale. Normalization helps ensure that the pixel values have a consistent scale and distribution, and a grayscale image seems pretty consistent (no radical colors, just white, grey, and black). However, this isn’t fully true.
Yes, grayscale images range from black to white, with the intensity of light being the biggest factor in how the values are distributed. BUT that does not change the fact that the values still range from 0 to 255 (for an 8-bit image). 0 is black, 255 is white, and the values in between show how intense the light is at that pixel in the image.
So, to answer the question above: yes! Normalizing grayscale images is still useful and should be done!
Normalizing these pixel values to a smaller range (e.g., [0, 1]) can make training more stable and efficient, especially when using activation functions like sigmoid or ReLU.
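Some grayscale formats store more than 8 bits per pixel, in which case dividing by 255 would not land in [0, 1]. The sketch below is one possible approach, not the only one: it min-max scales a single-channel image using its own minimum and maximum. The 12-bit pixel values are hypothetical.

```python
import torch

# Hypothetical 12-bit grayscale pixel data (values between 0 and 4095)
pixels = torch.tensor([[0, 1024], [2048, 4095]], dtype=torch.int32).float()

# Min-max scaling with the image's own range (assumes hi > lo)
lo, hi = pixels.min(), pixels.max()
scaled = (pixels - lo) / (hi - lo)

print(scaled.min().item(), scaled.max().item())
```

Using a fixed range for the whole dataset instead of each image's own minimum and maximum is an equally reasonable choice, depending on the data.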
Then how do we normalize images?
In order to normalize the images, you first have to obtain the mean and the standard deviation of the images. Remember every color image has 3 channels (RGB), so you will have to obtain the mean and standard deviation of all three channels. Additionally, if you are fine-tuning a pre-trained model that was trained on color images, you most likely will also have to obtain mean and standard deviation of all three channels!
If you do not normalize your own images, the pre-trained model will likely not perform as well on them.
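To make the role of the mean and standard deviation concrete: per-channel normalization subtracts each channel's mean and divides by its standard deviation, which is also what transforms.Normalize does. Here is a minimal sketch on a tiny tensor; the pixel values, mean, and std are made up for illustration.

```python
import torch

# Hypothetical 3-channel image of shape [channel, height, width] = [3, 1, 2]
image = torch.tensor([
    [[0.0, 1.0]],  # R
    [[0.2, 0.6]],  # G
    [[0.5, 0.5]],  # B
])

# Made-up per-channel statistics
mean = torch.tensor([0.5, 0.4, 0.5])
std = torch.tensor([0.5, 0.2, 0.25])

# Reshape to [3, 1, 1] so each value broadcasts over its own channel
normalized = (image - mean.view(3, 1, 1)) / std.view(3, 1, 1)

print(normalized)
```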
To obtain these values, you first have to load the image dataset! (Skip to the Mean and Standard Deviation section if you already know how to load image datasets with PyTorch.)
Loading in the Dataset
To load image data that you have collected, you will need to use datasets.ImageFolder from torchvision.
ImageFolder
import torch
from torchvision import datasets, transforms

data_path = 'path/to/image_data'
# transform_img is the transform pipeline defined in the "Transforming the images" section
image_data = datasets.ImageFolder(root=data_path, transform=transform_img)
The image_data variable is an instance of the ImageFolder class that loads images from the specified directory and applies the transform_img function to them.
Before loading the data, you also need to make sure that the images in the folder are separated by class, with each class in its own directory.
For example, if you have images of pandas and koalas, then you will need to create two separate folders, pandas and koalas, within the train_data folder, and put each image in its respective folder.
# Example of how directories should be made to use datasets.ImageFolder()
train_data/pandas/panda_1.png
train_data/pandas/panda_2.png
train_data/pandas/panda_3.png
train_data/koalas/koalas_1.png
train_data/koalas/koalas_2.png
train_data/koalas/koalas_3.png
Transforming the images
import torch
from torchvision import datasets, transforms

transform_img = transforms.Compose([
    # Resize(224) only matches the shorter edge; (224, 224) fixes both dimensions
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # this is where you would normally add in a method for normalization
])

data_path = 'path/to/image_data'
image_data = datasets.ImageFolder(root=data_path, transform=transform_img)
As you can see, the ImageFolder class takes an argument called transform, which is a function/transform that takes in a PIL (Python Imaging Library) image and returns a transformed version.
With torchvision's transforms, you can resize images with transforms.Resize() or crop them with transforms.CenterCrop(), transforms.RandomResizedCrop(), etc.
We’ll also need to convert the images to PyTorch tensors with transforms.ToTensor(). Typically you'll combine these transforms into a pipeline with transforms.Compose(), which accepts a list of transforms and runs them in sequence. Generally you’ll go through a sequence of scaling, cropping, and converting the image into a tensor before normalization.
DataLoader
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform_img = transforms.Compose([
    # every image must have the same size to be stacked into one batch
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

data_path = 'path/to/image_data'
image_data = datasets.ImageFolder(root=data_path, transform=transform_img)

image_data_loader = DataLoader(
    image_data,
    batch_size=len(image_data),
    shuffle=False,
    num_workers=0
)
According to PyTorch’s documentation, a Dataset object retrieves the features and labels of a dataset one sample at a time.
However, when training a model, we typically want to pass the samples in batches, reshuffle the data at every epoch to reduce overfitting, and use Python's multiprocessing to speed up data retrieval.
The DataLoader object is an iterable that abstracts this complexity away from us and provides a simple API to access the data. Since we are trying to figure out the mean and standard deviation of all of our images, we set the DataLoader's batch_size to the length of the image data, so that a single batch holds every image. We also don’t need to shuffle the data.
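To show how batch_size changes what the loader returns, here is a minimal sketch that uses random tensors wrapped in a TensorDataset as a stand-in for real images. The sizes and labels are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 10 fake RGB images of 8x8 pixels with made-up labels
images = torch.rand(10, 3, 8, 8)
labels = torch.randint(0, 2, (10,))
dataset = TensorDataset(images, labels)

# Ordinary mini-batches: 10 samples split into batches of 4, 4, and 2
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_shapes = [tuple(batch_images.shape) for batch_images, _ in loader]
print(batch_shapes)

# One batch holding the whole dataset, as used for computing statistics
full_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
all_images, all_labels = next(iter(full_loader))
print(all_images.shape)
```

Keep in mind that a single full-size batch holds every image in memory at once, so this only suits datasets that fit in RAM.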
Mean and Standard Deviation
With the image_data_loader that we created above, we can now calculate the mean and standard deviation of our images.
We will first define a function called mean_std that takes in a loader and returns the mean and standard deviation. First we will retrieve all of the images and labels from the image_data_loader using next(iter(loader)). Since we only have one batch, it will return all of the images in our dataset.
Then we can calculate the mean and the standard deviation, each of which is returned as a tensor of length 3. The elements of each tensor represent the values for R, G, and B respectively.
def mean_std(loader):
    # retrieves a batch from the loader
    images, labels = next(iter(loader))
    # images has shape [batch, channel, height, width];
    # reducing over dims 0, 2, 3 leaves one value per channel
    mean, std = images.mean([0, 2, 3]), images.std([0, 2, 3])
    return mean, std

image_data_loader = DataLoader(
    image_data,
    batch_size=len(image_data),
    shuffle=False,
    num_workers=0)

mean, std = mean_std(image_data_loader)
print("mean and std: \n", mean, std)
mean and std:
tensor([0.3322, 0.0275, 0.1132]) tensor([0.2215, 0.0965, 0.3152])
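If the dataset is too large to load as one batch, the same statistics can be accumulated batch by batch. The sketch below is an alternative, not the approach used in this post: the function name mean_std_streaming and the random stand-in data are hypothetical, and it assumes every image has the same size and number of channels.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def mean_std_streaming(loader, num_channels=3):
    """Accumulate per-channel sums over all batches, then derive mean and std."""
    channel_sum = torch.zeros(num_channels, dtype=torch.float64)
    channel_sq_sum = torch.zeros(num_channels, dtype=torch.float64)
    num_pixels = 0  # pixels per channel seen so far
    for images, _ in loader:
        images = images.double()  # float64 limits rounding error in the sums
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
        num_pixels += images.shape[0] * images.shape[2] * images.shape[3]
    mean = channel_sum / num_pixels
    # Population std: sqrt(E[x^2] - E[x]^2)
    std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
    return mean.float(), std.float()


# Random stand-in for a real image dataset
torch.manual_seed(0)
fake_images = torch.rand(32, 3, 16, 16)
fake_labels = torch.zeros(32, dtype=torch.long)
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=8)

mean, std = mean_std_streaming(loader)
print(mean, std)
```

This computes the population standard deviation, while images.std() applies Bessel's correction by default. With thousands of pixels per channel, the difference is negligible.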
Now that we have our mean and standard deviation, we can go back and create a data loader for our training dataset!
Creating Training Data Loader
Remember how I mentioned that you’ll go through a sequence of scaling, cropping, converting the image into a tensor, and normalization? Well, since we obtained our mean and standard deviation, we can finally get to normalization!!
A quick note: most real-world images come in different sizes. Remember to resize them to the input size expected by the model (or pre-trained model) you are creating or fine-tuning.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform_img = transforms.Compose([
    transforms.Resize((224, 224)),  # Did you know that VGG16 was trained on 224x224 pixel images?
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.3322, 0.0275, 0.1132],
        std=[0.2215, 0.0965, 0.3152]),  # added our normalization
])

data_path = 'path/to/train_data'
train_dataset = datasets.ImageFolder(root=data_path, transform=transform_img)

train_data_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
)
Awesome! Now your images are normalized and ready to go into training!
Thank you for reading to the end of the post! I hope you guys were able to learn a little something from this! I’m always trying to learn and have been posting content on AI, so give me a follow as well if you are interested!
Also, if you guys are interested in learning more, I learned a lot from some of the sources below, so I really recommend checking them out.