How to create a custom Dataset / Loader in PyTorch, from Scratch, for multi-band Satellite Images Dataset from Kaggle

Update
For information about the course Introduction to Python for Scientists (available on YouTube) and other articles like this, please visit my website cordmaur.carrd.co.
Introduction
In my last Medium story (here) I proposed an approach using the high-level API Fast.ai to detect cloud contours in satellite images. Detecting object contours (i.e. all the pixels belonging to the same object) is called semantic segmentation. The dataset that was used for the task is the 38-Cloud: Cloud Segmentation in Satellite Images, from Kaggle.
Although we achieved a relatively good accuracy of 96% with a few lines of code, the model was not able to use all the input channels provided in the dataset: Red, Green, Blue and NIR (Near Infrared). The problem is that most of the semantic segmentation models found in deep learning frameworks like Keras, Fast.ai and even PyTorch are designed to work with RGB images, and their pre-trained weights assume RGB input. Besides that, the vision modules of these libraries are also tied to RGB files. That is the reason we ignored the NIR channel in the previous story and used only RGB patches.
This pushed me to go for a completely different approach here: build my own (very simple) U-Net from scratch, without using the vision libraries to prepare the dataset. As a start, in this story I will show how to create the dataset using plain PyTorch functions. The next one, Creating a Very Simple U-Net Model with PyTorch for Semantic Segmentation of Satellite Images, is a continuation that explains how to create a simple segmentation model and train it to achieve great results. Let’s go.
Data Preparation
As I already explained in the last article, this dataset is composed of 8400 training patches of size 384x384 (suitable for deep learning purposes). Another set of 9201 patches is left for testing, but we are not dealing with them for now. The patches are separated into directories for the Red, Green, Blue and NIR (Near Infrared) channels, and there is an additional directory for the reference mask (ground truth, *_gt). The structure is shown in Figure 1.

Instead of downloading the whole dataset to your computer (which takes a lot of space), it is possible to create a notebook directly in Kaggle. I published a Kaggle notebook (here) with all the necessary code.
The PyTorch Dataset class
In the last article we created the rgb_patch*.tif files on disk, using PIL to combine the bands into 384x384x3 png files. With the files on disk, we then created the dataset using fastai.vision’s ImageDataBunch.from_folder function.
Instead of using torchvision to read the files, I decided to create my own dataset class that reads the Red, Green, Blue and NIR patches and stacks them all into a single tensor. To do that, we need to create a custom PyTorch Dataset by inheriting a new class from torch.utils.data.Dataset. Our new dataset class basically needs to do the following:
- map the files available in the dataset and return the number of items through the __len__() function,
- combine all the 4 channels into a single tensor, without the need to save it to disk, and
- for a given index, return a tuple (x, y) with a sample from the dataset, where x is the 4-channel tensor and y is the ground truth mask. For that, we should define the __getitem__() function.
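To make the three requirements above concrete, here is a minimal sketch of the Dataset contract. It is not the class built in this story: to stay runnable without the Kaggle files, it assumes the patches are already in memory as NumPy arrays, and the names FourBandDataset and make_fake_patch are hypothetical, used only for illustration.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class FourBandDataset(Dataset):
    """Toy dataset (hypothetical name): each patch is a dict of 2-D arrays."""

    BANDS = ("red", "green", "blue", "nir")

    def __init__(self, patches):
        # patches: list of dicts with the keys in BANDS plus "gt"
        self.patches = patches

    def __len__(self):
        # Number of samples available in the dataset
        return len(self.patches)

    def __getitem__(self, idx):
        patch = self.patches[idx]
        # Stack the 4 bands along a new first axis -> shape (4, H, W)
        x = np.stack([patch[band] for band in self.BANDS], axis=0)
        x = torch.from_numpy(x.astype(np.float32))
        # Ground truth mask as integer class labels -> shape (H, W)
        y = torch.from_numpy(patch["gt"].astype(np.int64))
        return x, y


def make_fake_patch(rng, size=384):
    """Random stand-in for one patch (assumption: real data comes from files)."""
    patch = {band: rng.random((size, size)) for band in FourBandDataset.BANDS}
    patch["gt"] = rng.integers(0, 2, size=(size, size))
    return patch


rng = np.random.default_rng(0)
dataset = FourBandDataset([make_fake_patch(rng) for _ in range(3)])
x, y = dataset[0]
print(len(dataset), tuple(x.shape), tuple(y.shape))
```

Any object that implements __len__() and __getitem__() this way can be passed to torch.utils.data.DataLoader, which takes care of batching and shuffling.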
Our class will receive as input the directories of each channel and start by mapping the files, combining the paths for all the bands of each patch into a dictionary.
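The mapping step could be sketched as follows. This is only an illustration under stated assumptions: the helper names combine_files and map_files are hypothetical, and the sketch assumes that the files of one patch share the same name in every directory except for a band prefix (for example red_..., green_..., nir_..., gt_...). Check the real file names in the dataset before relying on it.

```python
from pathlib import Path


def combine_files(red_file, green_dir, blue_dir, nir_dir, gt_dir):
    """Build the path dictionary for one patch, starting from its red-band file.

    Assumption: file names differ between directories only by the band
    prefix, so replacing the first 'red' in the name gives the sibling file.
    """
    name = red_file.name
    return {
        "red": red_file,
        "green": Path(green_dir) / name.replace("red", "green", 1),
        "blue": Path(blue_dir) / name.replace("red", "blue", 1),
        "nir": Path(nir_dir) / name.replace("red", "nir", 1),
        "gt": Path(gt_dir) / name.replace("red", "gt", 1),
    }


def map_files(red_dir, green_dir, blue_dir, nir_dir, gt_dir):
    """Return one dictionary of band paths per patch found in red_dir."""
    red_files = sorted(f for f in Path(red_dir).iterdir() if f.is_file())
    return [
        combine_files(f, green_dir, blue_dir, nir_dir, gt_dir)
        for f in red_files
    ]
```

Calling map_files with the five directories returns a list whose length is the number of patches, which is exactly what __len__() needs, and each entry holds the paths that __getitem__() has to open for a given index.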








