Summary
This article provides a guide on creating an imbalanced dataset and oversampling under-represented samples using the tf.data.Dataset class in TensorFlow.
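The first step the summary describes is turning a balanced dataset into an imbalanced one. Below is a minimal sketch using only `tf.data.Dataset` calls. It assumes a small synthetic stand-in for CIFAR-10 and a hypothetical per-class budget `KEEP_PER_CLASS`. The helper `class_subset` is also hypothetical and is not the article's `allowed_id_list()` or `_filter_list()`.

```python
import functools

import numpy as np
import tensorflow as tf

# Synthetic stand-in for a balanced dataset such as CIFAR-10 (assumption):
# 100 tiny random "images" for each of 3 classes.
NUM_CLASSES = 3
labels = np.repeat(np.arange(NUM_CLASSES), 100).astype(np.int64)
images = np.random.rand(labels.shape[0], 8, 8, 3).astype(np.float32)
balanced_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Hypothetical budget: class 0 stays complete, classes 1 and 2 become minorities.
KEEP_PER_CLASS = {0: 100, 1: 20, 2: 5}


def class_subset(ds, class_id, count):
    # Keep only examples of `class_id`, then the first `count` of them.
    return ds.filter(lambda x, y: tf.equal(y, class_id)).take(count)


# Concatenate the per-class subsets into one imbalanced dataset.
# The classes stay grouped, so shuffle before training on it.
imbalanced_ds = functools.reduce(
    lambda a, b: a.concatenate(b),
    [class_subset(balanced_ds, c, n) for c, n in KEEP_PER_CLASS.items()],
)


def class_counts(ds):
    # Count how many examples of each class a dataset yields.
    counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    for _, y in ds.as_numpy_iterator():
        counts[y] += 1
    return counts
```

Each `filter()` call scans the full dataset once. That is fine for a demonstration, but on a large dataset a single pass that decides membership per example would be cheaper.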
Abstract
The article outlines a TensorFlow implementation for generating an imbalanced dataset and subsequently oversampling the minority classes to balance the dataset. It is intended for individuals looking to work with the tf.data.Dataset class, create imbalanced datasets, oversample underrepresented classes, apply data augmentations, and split datasets into training and validation sets. The author uses the Keras library (version 2.4.3) with a TensorFlow (version 2.2.0) backend and demonstrates the process using the CIFAR-10 dataset. Key functions and methods discussed include allowed_id_list(), _filter_list(), ds.take(), ds.skip(), oversample_classes(), map_fn(), and batch_map_fn(). The article also addresses the importance of avoiding data leakage by applying oversampling after dataset splitting and the implications of shuffling within subsets. The author invites feedback and suggests further reading on the tf.data.Dataset class.
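The abstract lists separate per-example and per-batch mapping functions (`map_fn()` and `batch_map_fn()`). The sketch below shows one way to split augmentation between the two stages of a `tf.data` pipeline. The functions `augment_example` and `augment_batch` and the specific augmentations are assumptions for illustration. They may not match what the article's own functions do.

```python
import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Small synthetic dataset of HWC float images in [0, 1] (assumption).
images = np.random.rand(10, 8, 8, 3).astype(np.float32)
labels = (np.arange(10) % 2).astype(np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))


def augment_example(image, label):
    # Per-element step: random horizontal flip of a single image.
    image = tf.image.random_flip_left_right(image)
    return image, label


def augment_batch(images, labels):
    # Per-batch step: one random brightness delta for the whole NHWC batch,
    # then clip back to the [0, 1] range of the inputs.
    images = tf.image.random_brightness(images, max_delta=0.1)
    return tf.clip_by_value(images, 0.0, 1.0), labels


train_ds = (
    ds.map(augment_example, num_parallel_calls=AUTOTUNE)
    .batch(4)
    .map(augment_batch, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)
```

Mapping after `batch()` runs the function once per batch rather than once per image. This reduces per-call overhead for operations that vectorize well.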
Opinions
- The author believes that TensorFlow's current tf.data.Dataset API makes it possible to implement an oversampling algorithm, which is crucial for dealing with imbalanced datasets.
- They suggest that readers should execute the oversample_classes mapping after splitting the data to prevent data leakage.
- The author emphasizes the importance of setting reshuffle_each_iteration=True in ds.shuffle() to ensure proper shuffling within subsets.
- They encourage readers to engage with the content by leaving comments or sharing the article if they find it helpful or feel that something has been missed.
- The author expresses a desire for support from readers, suggesting they follow the author or sign up for a Medium membership to access more of their content.
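The opinions about oversampling with the tf.data API, and about doing it only after the split, can be illustrated together. Below is a minimal sketch, assuming synthetic data with 40 examples of class 0 and 10 of class 1. It splits with `take()`/`skip()` first and then oversamples only the training subset by repeating each example an integer number of times with `flat_map()`. The function `oversample` is hypothetical and is not the article's `oversample_classes()`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical imbalanced data: 40 examples of class 0, 10 of class 1.
labels = np.array([0] * 40 + [1] * 10, dtype=np.int64)
features = np.arange(labels.shape[0], dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices((features, labels))

# Fix the order once so take/skip always carve out the same disjoint split.
ds = ds.shuffle(50, seed=0, reshuffle_each_iteration=False)
VAL_SIZE = 10
val_ds = ds.take(VAL_SIZE)
train_ds = ds.skip(VAL_SIZE)

# Class counts come from the training split only, so validation examples
# affect neither the repeat factors nor which examples get duplicated.
train_counts = np.bincount(
    [y for _, y in train_ds.as_numpy_iterator()], minlength=2
)
repeat_factors = tf.constant(train_counts.max() // train_counts, dtype=tf.int64)


def oversample(feature, label):
    # Emit each training example repeat_factors[label] times.
    count = tf.gather(repeat_factors, label)
    return tf.data.Dataset.from_tensors((feature, label)).repeat(count)


train_ds = train_ds.flat_map(oversample)
```

If oversampling ran before the split, copies of the same minority example could land in both subsets and inflate validation scores. Integer repeat factors leave some residual imbalance when class counts do not divide evenly. `tf.data.experimental.sample_from_datasets` is a probabilistic alternative that draws from per-class datasets with chosen weights.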
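The point about `reshuffle_each_iteration` is easiest to see with two separate shuffles. The full dataset is shuffled once with `reshuffle_each_iteration=False`, so `take()`/`skip()` select the same examples on every epoch. If it reshuffled, validation examples would drift into training. The training subset is then shuffled again with `reshuffle_each_iteration=True`, so its order changes each epoch. This is a minimal sketch of that common arrangement, using `tf.data.Dataset.range(50)` as stand-in data. The seeds and sizes are arbitrary assumptions.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(50)

# Shuffle the whole dataset once with a fixed order, so take/skip always
# produce the same disjoint validation and training subsets.
ds = ds.shuffle(50, seed=1, reshuffle_each_iteration=False)
val_ds = ds.take(10)
train_ds = ds.skip(10)

# Reshuffle within the training subset on every pass (epoch).
train_ds = train_ds.shuffle(40, seed=2, reshuffle_each_iteration=True)

# Two passes over train_ds stand in for two training epochs.
epoch_1 = list(train_ds.as_numpy_iterator())
epoch_2 = list(train_ds.as_numpy_iterator())
```

Both epochs contain the same 40 training elements in different orders, and none of them appear in `val_ds`.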