Anuj Arora

Summary

This article provides a guide on creating an imbalanced dataset and oversampling under-represented samples using the tf.data.Dataset class in TensorFlow.

Abstract

The article outlines a TensorFlow implementation for generating an imbalanced dataset and subsequently oversampling the minority classes to balance the dataset. It is intended for individuals looking to work with the tf.data.Dataset class, create imbalanced datasets, oversample underrepresented classes, apply data augmentations, and split datasets into training and validation sets. The author uses the Keras library (version 2.4.3) with a TensorFlow (version 2.2.0) backend and demonstrates the process using the CIFAR-10 dataset. Key functions and methods discussed include allowed_id_list(), _filter_list(), ds.take(), ds.skip(), oversample_classes(), map_fn(), and batch_map_fn(). The article also addresses the importance of avoiding data leakage by applying oversampling after dataset splitting and the implications of shuffling within subsets. The author invites feedback and suggests further reading on the tf.data.Dataset class.

Opinions

  • The author believes that the current API of TensorFlow's datasets allows for the implementation of an oversampling algorithm, which is crucial for dealing with imbalanced datasets.
  • They suggest that readers should execute the oversample_classes mapping after splitting the data to prevent data leakage.
  • The author emphasizes the importance of setting reshuffle_each_iteration=True in ds.shuffle() to ensure proper shuffling within subsets.
  • They encourage readers to engage with the content by leaving comments or sharing the article if they find it helpful or feel that something has been missed.
  • The author expresses a desire for support from readers, suggesting they follow the author or sign up for a Medium membership to access more of their content.

Creating an Imbalanced Dataset With Oversampling Using tf.data.Dataset

In this article I summarize a TensorFlow implementation for 1) creating an imbalanced dataset and 2) oversampling the under-represented samples using tf.data.Dataset.

Who is this article aimed at? It is for you if you want to

  1. work with the tf.data.Dataset class
  2. create an imbalanced dataset
  3. oversample the underrepresented samples in an imbalanced dataset
  4. apply image-level and batch-level data augmentations
  5. create a split (train, validation) from a given dataset

Tool and dataset specifications:

  • Keras (v2.4.3)
  • TensorFlow (v2.2.0) backend
  • CIFAR-10 dataset, fed through a tf.data.Dataset pipeline

TL;DR Below is the bare minimum code snippet that will fulfill these requirements. You can copy and experiment or scroll down to read a brief on what each code block is doing.

Brief

  1. allowed_id_list(): randomly chooses “imbalance_sample_size” samples for the imbalanced classes (:= screened_labels).
  2. _filter_list(): keeps only the allowed samples for the imbalanced classes and drops the rest of their samples.
  3. ds.take() and ds.skip() create the train/val splits.
  4. oversample_classes() repeats the samples from the imbalanced classes. The current implementation returns “2” for samples from minority classes and “1” for the rest. For more possibilities, please refer below.
  5. map_fn() and batch_map_fn() can be used for augmenting a single image and a whole batch of images, respectively.
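The steps above can be sketched roughly as follows. This is a minimal illustration rather than the original implementation: the function names follow the brief, but their bodies are reconstructed assumptions. A small synthetic array stands in for CIFAR-10, "imbalance_sample_size" is interpreted as a per-class count, and the brightness jitter in batch_map_fn() is only a placeholder for a batch-level augmentation.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for CIFAR-10 (assumption): 300 tiny images, 3 classes,
# permuted once with a fixed seed so that take()/skip() see mixed classes.
rng = np.random.RandomState(0)
labels_np = rng.permutation(np.repeat(np.arange(3, dtype=np.int64), 100))
images_np = rng.rand(labels_np.size, 8, 8, 3).astype("float32")

screened_labels = [0]        # classes to turn into minority classes
imbalance_sample_size = 20   # samples kept per screened class (assumption)


def allowed_id_list(labels, screened, keep):
    """Ids of every majority sample plus `keep` random ids per screened class."""
    ids = np.flatnonzero(~np.isin(labels, screened)).tolist()
    for cls in screened:
        cls_ids = np.flatnonzero(labels == cls)
        ids.extend(rng.choice(cls_ids, size=keep, replace=False).tolist())
    return ids


allowed = tf.constant(
    allowed_id_list(labels_np, screened_labels, imbalance_sample_size),
    dtype=tf.int64)


def _filter_list(idx, sample):
    """Keep a sample only if its position is in the allowed id list."""
    return tf.reduce_any(tf.equal(idx, allowed))


ds = tf.data.Dataset.from_tensor_slices((images_np, labels_np))
ds = ds.enumerate().filter(_filter_list).map(lambda idx, sample: sample)

# Split BEFORE oversampling or shuffling, so the two subsets stay disjoint.
val_size = 44                # 20% of the 220 samples that remain
val_ds = ds.take(val_size)
train_ds = ds.skip(val_size)

minority = tf.constant(screened_labels, dtype=tf.int64)


def oversample_classes(image, label):
    """Repeat count: 2 for minority-class samples, 1 for everything else."""
    is_minority = tf.reduce_any(tf.equal(label, minority))
    return tf.where(is_minority, tf.constant(2, tf.int64), tf.constant(1, tf.int64))


train_ds = train_ds.flat_map(
    lambda image, label: tf.data.Dataset.from_tensors((image, label))
    .repeat(oversample_classes(image, label)))


def map_fn(image, label):
    """Per-image augmentation: random horizontal flip."""
    return tf.image.random_flip_left_right(image), label


def batch_map_fn(images, labels):
    """Batch-level augmentation (placeholder): one brightness shift per batch."""
    images = tf.image.random_brightness(images, max_delta=0.1)
    return tf.clip_by_value(images, 0.0, 1.0), labels


train_ds = (train_ds.shuffle(512, reshuffle_each_iteration=True)
            .map(map_fn).batch(32).map(batch_map_fn))

# Class histogram of the oversampled training set.
counts = np.bincount(
    np.concatenate([lbl.numpy() for _, lbl in train_ds]), minlength=3)
print("train samples per class after oversampling:", counts)
```

Returning a repeat count from oversample_classes() and expanding it with flat_map() keeps the oversampling inside the tf.data pipeline, so a different weighting only requires changing the returned count.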

Note:

  1. Make sure to apply the oversample_classes mapping after splitting the data into validation and train datasets. Otherwise, there is a possibility of data leakage, because copies of the same sample can land in both splits. The same goes for ds.shuffle().
  2. Keep “reshuffle_each_iteration=True” (the default) in ds.shuffle(). Setting it to False replays the same order in every epoch, which restricts (to a certain extent) the shuffling of the dataset within the subsets.
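Both notes can be checked on a toy dataset. The sketch below uses a range of integer ids in place of real samples (an assumption made purely for illustration) and shows that splitting first keeps the validation ids fixed, while shuffling before take() generally does not.

```python
import tensorflow as tf

ids = tf.data.Dataset.range(20)   # stand-in sample ids (assumption)
val_size = 5

# Safe order: split first, then shuffle only the training subset.
val_ds = ids.take(val_size)
train_ds = ids.skip(val_size).shuffle(20, reshuffle_each_iteration=True)

val_ids = {int(x) for x in val_ds}
for epoch in range(3):
    train_ids = [int(x) for x in train_ds]
    # The training order changes per epoch, but never overlaps validation.
    assert val_ids.isdisjoint(train_ids)
    print("epoch", epoch, "train order:", train_ids)

# Leaky order: shuffling before take() redraws the "validation" subset on
# every pass, so its membership will usually differ between epochs.
leaky_val = ids.shuffle(20, reshuffle_each_iteration=True).take(val_size)
print("leaky val, pass 1:", sorted(int(x) for x in leaky_val))
print("leaky val, pass 2:", sorted(int(x) for x in leaky_val))

# reshuffle_each_iteration=False: every epoch replays the same order.
fixed = ids.shuffle(20, seed=0, reshuffle_each_iteration=False).repeat(2)
order = [int(x) for x in fixed]
same_each_epoch = order[:20] == order[20:]
print("same order in both epochs:", same_each_epoch)
```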

That covers the basics of building an imbalanced dataset or a dataset with oversampled classes. Please leave a comment or share if you feel I have missed something.

If you are interested in further exploring tf.data.Dataset class, you can check out my following article.

If you find stories like these valuable and would like to support me as a writer, please consider following me or signing up for Medium membership.
