Sumit Pandey

Summary

The provided content is a comprehensive tutorial on using DINOv2, a state-of-the-art self-supervised vision transformer, for custom dataset segmentation, detailing the process from library installation to model training and evaluation.

Abstract

The article presents a step-by-step guide to using DINOv2 for semantic segmentation on a custom dataset. It begins by noting how eagerly the computer vision community anticipated the model and credits the GitHub repository that inspired the tutorial. The author describes two bugs encountered while running the repository's code, which caused errors during training. The tutorial's plan of attack covers an introduction to DINOv2, installation of the necessary libraries, loading the FoodSeg103 dataset, creating a PyTorch dataset and dataloaders, defining the model architecture, and training. Only a linear layer (a 1x1 convolution) that maps the features of a frozen DINOv2 backbone to logits is trained, for a single epoch. After that one epoch the model reaches a mean IoU of 0.37, and the article concludes with a link to the updated Colab notebook and a reference to the DINOv2 paper.

Opinions

  • The author motivates the tutorial by referencing the excitement around DINOv2 and the challenges faced while following an existing tutorial, implying a need for clearer guidance.
  • The author expresses satisfaction with the results achieved after training for only one epoch, indicating the effectiveness of DINOv2's features for downstream tasks.
  • There is an acknowledgment of the non-optimized training hyperparameters, suggesting that further fine-tuning could improve the model's performance.
  • The author encourages readers to experiment with the model and hyperparameters, highlighting the exploratory nature of the tutorial.
  • By providing a link to the Colab notebook, the author invites the community to engage with and build upon the work presented in the tutorial.

DINOv2 for Custom Dataset Segmentation: A Comprehensive Tutorial

After YOLOv8 and SAM (Segment Anything Model), DINOv2 is arguably the most anticipated computer vision model. The motivation for this tutorial came from Niels Rogge's Transformers-Tutorials repository: https://github.com/NielsRogge/Transformers-Tutorials/tree/master. While running its code, I found two bugs that caused annoying errors while training the model. In his tutorial, training is stopped after a few steps, while the errors arise partway through training and at the last training step. The code is taken from his notebook, except for a few changes :). Here is the plan of attack:

Plan of Attack

  1. Introduction of DINOv2
  2. Library installation
  3. Load dataset
  4. Create PyTorch dataset
  5. Create PyTorch dataloaders
  6. Define model
  7. Train the model

Introduction of DINOv2

DINOv2 is a vision transformer trained in a self-supervised manner on a carefully curated dataset of 142 million images. It produces general-purpose image features (embeddings) that transfer well to downstream tasks such as image classification, image segmentation, and depth estimation, even when the backbone is kept frozen.

Figure 1: Output of the working model from this tutorial after training for just one epoch (image by author)

Figure 1 conceptualizes this approach. In this tutorial, I simply train a linear transformation (a 1x1 convolutional layer) on top of a frozen DINOv2 backbone. This transformation maps the features (patch embeddings) to logits, the unnormalized scores a neural network outputs for each class. In semantic segmentation, the logits have the shape (batch_size, num_classes, height, width). That is one score per class for each pixel, and the highest-scoring class is the predicted class for that pixel.
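The following NumPy sketch shows how such logits are read. NumPy stands in for PyTorch so the sketch runs without a GPU or a model download, and the toy sizes are arbitrary. The argmax over the class axis gives the predicted class per pixel. A softmax over the same axis turns the scores into probabilities without changing which class wins.

```python
import numpy as np

def logits_to_prediction(logits):
    """Turn (batch, num_classes, H, W) logits into per-pixel class ids and probabilities."""
    # numerically stable softmax over the class axis (axis 1)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # the predicted class at each pixel is the index of the largest score
    pred = logits.argmax(axis=1)
    return pred, probs

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 104, 4, 4))  # toy batch: 2 images, 104 classes, 4x4 pixels
pred, probs = logits_to_prediction(logits)
print(pred.shape, probs.shape)  # (2, 4, 4) (2, 104, 4, 4)
```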

Library installation

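First, install the required libraries. These are Hugging Face Transformers (from source) and Datasets, plus Evaluate for the segmentation metric and Albumentations for data augmentation:

!pip install -q git+https://github.com/huggingface/transformers.git datasets

!pip install -q evaluate albumentations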

Load dataset

Next, let's load an image segmentation dataset. Here we use FoodSeg103, available on the Hugging Face Hub as EduardoPacheco/FoodSeg103.

from datasets import load_dataset

# load the dataset from the Hugging Face Hub
dataset = load_dataset("EduardoPacheco/FoodSeg103")

# labels: FoodSeg103 class ids and names
id2label = {
    0: "background",
    1: "candy",
    2: "egg tart",
    3: "french fries",
    4: "chocolate",
    5: "biscuit",
    6: "popcorn",
    7: "pudding",
    8: "ice cream",
    9: "cheese butter",
    10: "cake",
    11: "wine",
    12: "milkshake",
    13: "coffee",
    14: "juice",
    15: "milk",
    16: "tea",
    17: "almond",
    18: "red beans",
    19: "cashew",
    20: "dried cranberries",
    21: "soy",
    22: "walnut",
    23: "peanut",
    24: "egg",
    25: "apple",
    26: "date",
    27: "apricot",
    28: "avocado",
    29: "banana",
    30: "strawberry",
    31: "cherry",
    32: "blueberry",
    33: "raspberry",
    34: "mango",
    35: "olives",
    36: "peach",
    37: "lemon",
    38: "pear",
    39: "fig",
    40: "pineapple",
    41: "grape",
    42: "kiwi",
    43: "melon",
    44: "orange",
    45: "watermelon",
    46: "steak",
    47: "pork",
    48: "chicken duck",
    49: "sausage",
    50: "fried meat",
    51: "lamb",
    52: "sauce",
    53: "crab",
    54: "fish",
    55: "shellfish",
    56: "shrimp",
    57: "soup",
    58: "bread",
    59: "corn",
    60: "hamburg",
    61: "pizza",
    62: "hanamaki baozi",
    63: "wonton dumplings",
    64: "pasta",
    65: "noodles",
    66: "rice",
    67: "pie",
    68: "tofu",
    69: "eggplant",
    70: "potato",
    71: "garlic",
    72: "cauliflower",
    73: "tomato",
    74: "kelp",
    75: "seaweed",
    76: "spring onion",
    77: "rape",
    78: "ginger",
    79: "okra",
    80: "lettuce",
    81: "pumpkin",
    82: "cucumber",
    83: "white radish",
    84: "carrot",
    85: "asparagus",
    86: "bamboo shoots",
    87: "broccoli",
    88: "celery stick",
    89: "cilantro mint",
    90: "snow peas",
    91: "cabbage",
    92: "bean sprouts",
    93: "onion",
    94: "pepper",
    95: "green beans",
    96: "French beans",
    97: "king oyster mushroom",
    98: "shiitake",
    99: "enoki mushroom",
    100: "oyster mushroom",
    101: "white button mushroom",
    102: "salad",
    103: "other ingredients"
}

# visualize the images and masks
import numpy as np
import matplotlib.pyplot as plt

# map every class to a random color
id2color = {k: list(np.random.choice(range(256), size=3)) for k,v in id2label.items()}

def visualize_map(image, segmentation_map):
    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
    for label, color in id2color.items():
        color_seg[segmentation_map == label, :] = color

    # Show image + mask
    img = np.array(image) * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)

    plt.figure(figsize=(15, 10))
    plt.imshow(img)
    plt.show()

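# take one training example to visualize (image and segmentation_map were otherwise undefined)
example = dataset["train"][0]
image = example["image"]
segmentation_map = np.array(example["label"])
visualize_map(image, segmentation_map)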
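The per-class loop in visualize_map can also be written as a single palette lookup. The sketch below assumes the colours are stored in a (num_classes, 3) uint8 NumPy array. You could build it from id2color with np.array([id2color[k] for k in sorted(id2color)], dtype=np.uint8). The helper name colorize and the random data are purely illustrative.

```python
import numpy as np

def colorize(segmentation_map, palette):
    """Map an (H, W) array of class ids to an (H, W, 3) uint8 RGB image in one lookup."""
    # palette[k] is the RGB colour of class k; fancy indexing applies it to every pixel
    return palette[segmentation_map]

num_classes = 104
rng = np.random.default_rng(0)
palette = rng.integers(0, 256, size=(num_classes, 3)).astype(np.uint8)

seg = rng.integers(0, num_classes, size=(5, 7))  # toy segmentation map
color_seg = colorize(seg, palette)

# same result as the per-class loop used in visualize_map
loop_version = np.zeros((5, 7, 3), dtype=np.uint8)
for label in range(num_classes):
    loop_version[seg == label, :] = palette[label]

# 50/50 blend with an image, as visualize_map does
image = rng.integers(0, 256, size=(5, 7, 3)).astype(np.uint8)
blended = (image * 0.5 + color_seg * 0.5).astype(np.uint8)
```

The lookup avoids one full pass over the mask per class. With 104 classes, that saves 103 passes per image.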

Create PyTorch dataset

To prepare examples for the model, we create a standard PyTorch dataset that applies image augmentations. We resize every image to a fixed resolution of 448x448 pixels, randomly flip the training images horizontally, and normalize the color channels. For this we use the Albumentations library, although other libraries such as Torchvision or Kornia would work just as well.
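Geometric augmentations must be applied to the image and the mask together. The mask must also be resized with nearest-neighbour interpolation so that it still contains only valid class ids. As far as I know, Albumentations already does this for masks.

Below is a minimal NumPy sketch of both ideas, using hypothetical helpers resize_nearest and joint_hflip. The image is also resized nearest-neighbour here for brevity, whereas A.Resize uses bilinear interpolation for images by default.

```python
import numpy as np

def resize_nearest(array, out_h, out_w):
    """Nearest-neighbour resize of an (H, W) or (H, W, C) array; never invents new label values."""
    in_h, in_w = array.shape[:2]
    # source row/column that each output row/column maps back to
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return array[rows[:, None], cols[None, :]]

def joint_hflip(image, mask, rng, p=0.5):
    """Flip image and mask together with probability p so labels stay aligned with pixels."""
    if rng.random() < p:
        return image[:, ::-1], mask[:, ::-1]
    return image, mask

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(300, 500, 3)).astype(np.uint8)  # toy image
mask = rng.integers(0, 104, size=(300, 500))                        # toy label map

image_r = resize_nearest(image, 448, 448)
mask_r = resize_nearest(mask, 448, 448)
image_f, mask_f = joint_hflip(image_r, mask_r, rng, p=1.0)  # p=1.0 forces the flip for the demo
print(image_f.shape, mask_f.shape)  # (448, 448, 3) (448, 448)
```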

It’s important to remember that the model expects input pixel_values with the shape (batch_size, num_channels, height, width). Albumentations operates on NumPy arrays in channels-last format (height, width, channels), so we need to reorder the dimensions to put the channels first. In addition, the model requires labels of shape (batch_size, height, width), which give the ground-truth class for each pixel of every example in the batch.
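The normalization and the channel reordering can be sketched in NumPy as follows. This assumes A.Normalize's default max_pixel_value of 255, in which case it computes (pixel / 255 - mean) / std per channel. The helper to_model_input is hypothetical.

```python
import numpy as np

# per-channel ImageNet statistics in the 0-1 range (same values as ADE_MEAN / ADE_STD)
MEAN = np.array([123.675, 116.280, 103.530]) / 255
STD = np.array([58.395, 57.120, 57.375]) / 255

def to_model_input(image_hwc_uint8):
    """Normalize an (H, W, 3) uint8 image and reorder it to (3, H, W) float32."""
    # what A.Normalize(mean=MEAN, std=STD) computes with max_pixel_value=255
    normalized = (image_hwc_uint8.astype(np.float32) / 255.0 - MEAN) / STD
    # channels-last -> channels-first, like image.permute(2, 0, 1) in the dataset class
    return np.transpose(normalized, (2, 0, 1)).astype(np.float32)

image = np.full((448, 448, 3), 128, dtype=np.uint8)  # toy mid-grey image
x = to_model_input(image)
batch = np.stack([x, x])  # (batch_size, num_channels, height, width)
print(batch.shape)        # (2, 3, 448, 448)
```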

from torch.utils.data import Dataset
import torch

class SegmentationDataset(Dataset):
  def __init__(self, dataset, transform):
    self.dataset = dataset
    self.transform = transform

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    original_image = np.array(item["image"])
    original_segmentation_map = np.array(item["label"])

    transformed = self.transform(image=original_image, mask=original_segmentation_map)
    image, target = torch.tensor(transformed['image']), torch.LongTensor(transformed['mask'])

    # convert to C, H, W
    image = image.permute(2,0,1)

    return image, target, original_image, original_segmentation_map


# Let's create the training and validation datasets (note that random augmentation, here a horizontal flip, is applied only to the training images).

import albumentations as A

# per-channel ImageNet mean and std, scaled to the 0-1 range
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

train_transform = A.Compose([
    # had an issue with an image being too small to crop; PadIfNeeded didn't help...
    # if anyone knows why this happens, I'm happy to hear it
    # A.PadIfNeeded(min_height=448, min_width=448),
    # A.RandomResizedCrop(height=448, width=448),
    A.Resize(width=448, height=448),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
], is_check_shapes=False)

val_transform = A.Compose([
    A.Resize(width=448, height=448),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),

], is_check_shapes=False)

train_dataset = SegmentationDataset(dataset["train"], transform=train_transform)
val_dataset = SegmentationDataset(dataset["validation"], transform=val_transform)

pixel_values, target, original_image, original_segmentation_map = train_dataset[3]
print(pixel_values.shape)
print(target.shape)

Create PyTorch dataloaders

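Next, we create PyTorch dataloaders, which let us iterate over batches of data. Neural networks are typically trained on mini-batches with stochastic gradient descent (SGD) or a variant of it such as AdamW. The collate function stacks the images and labels along a new batch dimension. It keeps the original images and segmentation maps as lists, since their sizes vary.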
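The stacking step can be sketched with NumPy standing in for torch.stack. The collate helper and the toy shapes are illustrative. The fixed-size arrays gain a leading batch dimension, while the variable-size originals stay as plain lists.

```python
import numpy as np

def collate(examples):
    """Batch (pixel_values, labels, original_image, original_map) tuples."""
    return {
        # fixed-size arrays get a new leading batch axis
        "pixel_values": np.stack([ex[0] for ex in examples], axis=0),  # (B, 3, 448, 448)
        "labels": np.stack([ex[1] for ex in examples], axis=0),        # (B, 448, 448)
        # originals differ in size, so they can only be kept as lists
        "original_images": [ex[2] for ex in examples],
        "original_segmentation_maps": [ex[3] for ex in examples],
    }

rng = np.random.default_rng(0)
examples = [
    (
        rng.normal(size=(3, 448, 448)).astype(np.float32),
        rng.integers(0, 104, size=(448, 448)),
        np.zeros((h, 600, 3), dtype=np.uint8),  # original image of varying height
        np.zeros((h, 600), dtype=np.int64),     # matching original segmentation map
    )
    for h in (300, 400, 512)
]
batch = collate(examples)
print(batch["pixel_values"].shape, batch["labels"].shape)  # (3, 3, 448, 448) (3, 448, 448)
```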

from torch.utils.data import DataLoader

def collate_fn(inputs):
    batch = dict()
    batch["pixel_values"] = torch.stack([i[0] for i in inputs], dim=0)
    batch["labels"] = torch.stack([i[1] for i in inputs], dim=0)
    batch["original_images"] = [i[2] for i in inputs]
    batch["original_segmentation_maps"] = [i[3] for i in inputs]

    return batch

train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=3, shuffle=False, collate_fn=collate_fn)

batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v,torch.Tensor):
    print(k,v.shape)

Define model

Next, we define the model, which consists of DINOv2 as the backbone with a linear classifier on top. DINOv2 is a standard vision transformer, so it produces “patch embeddings”: one embedding vector for each image patch. We use an image resolution of 448 pixels and a DINOv2 model with a patch size of 14, so we obtain (448/14)² = 32² = 1024 patches. The backbone's last hidden state also contains a CLS token, giving it the shape (batch_size, 1 + 1024, 768). After dropping the CLS token, we keep a tensor of shape (batch_size, number of patches, hidden_size), i.e. (batch_size, 1024, 768), since the base model has a hidden size (embedding dimension) of 768.
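The token arithmetic can be checked with a small helper. The helper is hypothetical and not part of Transformers. It assumes a single CLS token and no register tokens, which to my understanding matches the facebook/dinov2-base checkpoint.

```python
def dinov2_token_grid(image_size, patch_size=14, num_special_tokens=1):
    """Return (patches per side, number of patch tokens, total sequence length)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size should be a multiple of patch_size")
    side = image_size // patch_size  # patches per row (and per column)
    num_patches = side * side
    # last_hidden_state also holds the CLS token (num_special_tokens=1 assumed)
    return side, num_patches, num_patches + num_special_tokens

print(dinov2_token_grid(448))  # (32, 1024, 1025)
```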

We then reshape this tensor to (batch_size, 32, 32, 768) and permute it to channels-first, (batch_size, 768, 32, 32). Next, we apply the linear layer, implemented as a Conv2D layer with a 1x1 kernel, which acts as a linear transformation applied independently at each position. This produces logits of shape (batch_size, num_labels, 32, 32). We upsample them with bilinear interpolation to (batch_size, num_labels, 448, 448), the shape semantic segmentation requires. This tensor contains the scores the model predicts for every class, at every pixel, for every example in the batch.
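The whole head can be traced with toy sizes in NumPy. In this sketch, the einsum plays the role of the 1x1 Conv2D. Nearest-neighbour upsampling replaces the model's bilinear interpolation to keep the code short. linear_head and all sizes are illustrative assumptions.

```python
import numpy as np

def linear_head(last_hidden_state, weight, bias, grid=32, out_size=448):
    """Sketch of the head: drop CLS, tokens -> grid, 1x1 conv, upsample."""
    patch_embeddings = last_hidden_state[:, 1:, :]             # (B, 1024, D)
    batch_size, _, dim = patch_embeddings.shape
    grid_hwc = patch_embeddings.reshape(batch_size, grid, grid, dim)  # (B, 32, 32, D), row-major
    grid_chw = grid_hwc.transpose(0, 3, 1, 2)                  # (B, D, 32, 32)
    # a 1x1 convolution is the same linear map applied at every grid position
    logits = np.einsum("kd,bdhw->bkhw", weight, grid_chw) + bias[None, :, None, None]
    # nearest-neighbour upsampling (the model itself uses bilinear interpolation)
    scale = out_size // grid
    return logits.repeat(scale, axis=2).repeat(scale, axis=3)  # (B, K, 448, 448)

rng = np.random.default_rng(0)
B, D, K = 2, 16, 5  # toy sizes; dinov2-base has D = 768 and this dataset K = 104
hidden = rng.normal(size=(B, 1 + 32 * 32, D))
W = rng.normal(size=(K, D))
bias = rng.normal(size=K)
out = linear_head(hidden, W, bias)
print(out.shape)  # (2, 5, 448, 448)
```

The token at grid position (i, j) is sequence index 1 + 32 * i + j, because the reshape is row-major and index 0 is the CLS token.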

import torch
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

class LinearClassifier(torch.nn.Module):
    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super(LinearClassifier, self).__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, (1,1))

    def forward(self, embeddings):
        embeddings = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        embeddings = embeddings.permute(0,3,1,2)

        return self.classifier(embeddings)


class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
  def __init__(self, config):
    super().__init__(config)

    self.dinov2 = Dinov2Model(config)
    self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)

  def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
    # use frozen features
    outputs = self.dinov2(pixel_values,
                            output_hidden_states=output_hidden_states,
                            output_attentions=output_attentions)
    # get the patch embeddings - so we exclude the CLS token
    patch_embeddings = outputs.last_hidden_state[:,1:,:]

    # convert to logits and upsample to the size of the pixel values
    logits = self.classifier(patch_embeddings)
    logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

    loss = None
    if labels is not None:
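      # important: we use 0 as the ignore index instead of the default -100,
      # as we don't want the model to learn to predict background
      loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)
      # no .squeeze() here: it would drop the batch dimension for a batch of one example,
      # leaving logits (num_labels, H, W) and labels (H, W), which CrossEntropyLoss rejects
      loss = loss_fct(logits, labels)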

    return SemanticSegmenterOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

# We can instantiate the model as follows:

model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))

# Important: we don't want to train the DINOv2 backbone, only the linear classification head. Hence we don't track any gradients for the backbone parameters, which greatly reduces memory usage:

for name, param in model.named_parameters():
  if name.startswith("dinov2"):
    param.requires_grad = False

# Let's perform a forward pass on a batch to verify the shape of the logits and check that we can compute a loss:

outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
print(outputs.logits.shape)
print(outputs.loss)
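The loss printed above is the per-pixel cross-entropy, averaged over pixels whose label is not the ignore index (0, i.e. background, in this model). The NumPy sketch below mirrors what torch.nn.CrossEntropyLoss(ignore_index=0) does with its default mean reduction and no class weights. pixel_cross_entropy is a hypothetical helper.

```python
import numpy as np

def pixel_cross_entropy(logits, labels, ignore_index=0):
    """Mean cross-entropy over pixels whose label != ignore_index.

    logits: (B, K, H, W) scores; labels: (B, H, W) integer class ids.
    """
    # stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    keep = labels != ignore_index
    # replace ignored labels by a valid index before gathering; they are masked out below
    safe = np.where(keep, labels, 0)
    picked = np.take_along_axis(log_probs, safe[:, None], axis=1)[:, 0]  # (B, H, W)
    if not keep.any():
        return float("nan")  # every pixel ignored: mean over an empty set
    return float(-picked[keep].mean())

labels = np.array([[[0, 1], [2, 3]]])  # (1, 2, 2); the pixel labelled 0 is ignored
logits = np.zeros((1, 4, 2, 2))        # uniform scores over 4 classes
print(pixel_cross_entropy(logits, labels))  # log(4) ≈ 1.386

# changing the scores at the ignored pixel leaves the loss unchanged
changed = logits.copy()
changed[0, :, 0, 0] = [5.0, -3.0, 2.0, 0.0]
print(pixel_cross_entropy(changed, labels))  # still log(4)
```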


import evaluate
metric = evaluate.load("mean_iou")
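Mean IoU is the intersection over union computed per class and then averaged over classes. The sketch below is a simplified stand-in for the evaluate metric, not its actual code. It masks out ignore_index pixels and skips classes that appear in neither the predictions nor the labels. It also works on a single batch, whereas the evaluate metric sums intersections and unions over all batches added since the last compute().

```python
import numpy as np

def mean_iou(pred, label, num_labels, ignore_index=0):
    """Simplified mean IoU over non-ignored pixels; returns (mean, per-class IoU)."""
    keep = label != ignore_index
    pred, label = pred[keep], label[keep]
    ious = np.full(num_labels, np.nan)
    for k in range(num_labels):
        inter = np.sum((pred == k) & (label == k))
        union = np.sum((pred == k) | (label == k))
        if union > 0:  # classes absent from both pred and label stay NaN
            ious[k] = inter / union
    return np.nanmean(ious), ious

label = np.array([[0, 1, 1], [2, 2, 2]])
pred = np.array([[1, 1, 2], [2, 2, 2]])
miou, per_class = mean_iou(pred, label, num_labels=4)
print(miou)  # 0.625: class 1 -> 1/2, class 2 -> 3/4, classes 0 and 3 skipped
```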

Train the model

Now let's train the model for one epoch:

from torch.optim import AdamW
from tqdm.auto import tqdm

# training hyperparameters
# NOTE: these are arbitrary values, not optimized at all
# feel free to experiment; see also the DINOv2 paper
learning_rate = 5e-5
epochs = 1

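# only the classifier head is trainable, so pass just those parameters to the optimizer
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=learning_rate)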

# put model on GPU (set runtime to GPU in Google Colab)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# put model in training mode
model.train()

for epoch in range(epochs):
  print("Epoch:", epoch)
  for idx, batch in enumerate(tqdm(train_dataloader)):
      pixel_values = batch["pixel_values"].to(device)
      labels = batch["labels"].to(device)

      # forward pass
      outputs = model(pixel_values, labels=labels)
      loss = outputs.loss

      loss.backward()
      optimizer.step()

      # zero the parameter gradients
      optimizer.zero_grad()

      # evaluate
      with torch.no_grad():
        predicted = outputs.logits.argmax(dim=1)

        # note that the metric expects predictions + labels as numpy arrays
        metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())

      # let's print loss and metrics every 100 batches
      if idx % 100 == 0:
        metrics = metric.compute(num_labels=len(id2label),
                                ignore_index=0,
                                reduce_labels=False,
        )

        print("Loss:", loss.item())
        print("Mean_iou:", metrics["mean_iou"])
        print("Mean accuracy:", metrics["mean_accuracy"])
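Conceptually, this loop is a linear probe: the features come from a backbone that never changes, and only the head's weights move. The sketch below shows the same idea in miniature. Random vectors stand in for the frozen DINOv2 patch embeddings and a softmax linear classifier serves as the head. Plain gradient descent replaces AdamW, and all sizes and the learning rate are arbitrary choices for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy "frozen features": fixed vectors whose labels follow a hidden linear rule
N, D, K = 600, 16, 4
features = rng.normal(size=(N, D))
true_w = rng.normal(size=(D, K))
labels = (features @ true_w).argmax(axis=1)

# the trainable linear head: the only parameters that get updated
W = np.zeros((D, K))
b = np.zeros(K)
lr = 0.5

def loss_and_grads(W, b):
    """Mean softmax cross-entropy of the head and its gradients w.r.t. W and b."""
    logits = features @ W + b
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(N), labels]).mean()
    # gradient of the mean cross-entropy w.r.t. the logits
    dlogits = probs.copy()
    dlogits[np.arange(N), labels] -= 1.0
    dlogits /= N
    return loss, features.T @ dlogits, dlogits.sum(axis=0)

losses = []
for step in range(200):
    loss, dW, db = loss_and_grads(W, b)
    losses.append(loss)
    W -= lr * dW  # plain gradient descent (the article uses AdamW)
    b -= lr * db

accuracy = ((features @ W + b).argmax(axis=1) == labels).mean()
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}, train accuracy {accuracy:.2f}")
```

The features array is never modified, just as the frozen backbone's parameters never receive gradients in the real loop.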

Results

Here are the results. I trained the model for only one epoch, and the mean IoU has already reached 0.37 (Result figure 1).

Result figure 1: IoU and loss metrics plot (image by author)

Result figure 2 shows a random selection of segmentation results.

Result figure 2: Segmentation results on test dataset after 1 epoch (image by author)

Done 😃😃😃😃😃

Please feel free to take a look at the updated Colab notebook: https://colab.research.google.com/drive/1UMQj7F_x0fSy_gevlTZ9zLYn7b02kTqi?usp=sharing

References

DINOv2: Learning Robust Visual Features without Supervision: https://arxiv.org/abs/2304.07193

If you like this work, please share it with your friends, and like and follow me here on Medium.
