Yennhi95zz

Summary

The article provides an overview of five data splitting strategies in machine learning, including random split, stratified split, time series split, K-Fold cross-validation, and Leave-One-Out cross-validation, with practical Python examples.

Abstract

The article "How to Split Data in Machine Learning: 5 Simple Strategies and Python Examples" delves into the critical process of data splitting, emphasizing its importance in training, tuning, and evaluating machine learning models. It outlines various strategies such as the random split, which is a basic approach for dividing data, and the stratified split, which maintains class distribution in imbalanced datasets. The article also discusses the time series split, crucial for maintaining the temporal order in time-dependent data. Furthermore, it explains K-Fold cross-validation for robust model evaluation and Leave-One-Out cross-validation, particularly useful for small datasets. Each strategy is accompanied by Python code examples using the scikit-learn library, providing readers with the tools to implement these methods effectively.

Opinions

  • The author suggests that data splitting is not merely about allocating 80% for training and 20% for testing but involves careful consideration of the dataset's characteristics.
  • The article promotes the use of stratified splitting to handle imbalanced datasets, ensuring that the class distribution is consistent across different subsets.
  • The author highlights the importance of preserving the temporal order in time series data, advocating for the use of time series split to simulate real-world predictive scenarios.
  • K-Fold cross-validation is presented as a superior method for evaluating model performance, especially when data is scarce, as it maximizes the utility of available data.
  • The author posits that Leave-One-Out cross-validation is an exhaustive approach suitable for small datasets, providing a comprehensive evaluation of a model's performance.
  • The article encourages the selection of an appropriate data splitting strategy based on the dataset and problem at hand to facilitate successful model development and evaluation.

How to Split Data in Machine Learning: 5 Simple Strategies and Python Examples

Data splitting isn’t just about using 80% for training and 20% for testing. Explore other important aspects.

Table of Contents

This article includes Chapter 4 of the All About Data Preprocessing in Machine Learning collection.

Chapter 4: Data Splitting: The Crucial Divide

  1. Random Split
  2. Stratified Split
  3. Time Series Split
  4. K-Fold Cross-Validation
  5. Leave-One-Out Cross-Validation

Introduction

Data Splitting is an important process in machine learning. It involves separating the dataset into distinct subsets such as a training set, a validation set, and a test set. This division is essential for training the model, tuning its parameters, and finally evaluating its performance. In this article, we dig into the importance of Data Splitting and explore some simple strategies, along with practical Python code examples to guide you through the process.

Photo by Mick Haupt on Unsplash

1. Random Split

The random split is a commonly used approach that randomly splits a dataset into a training set, validation set, and test set.

Data Splitting — Random State in Machine Learning (Image by the Author)
  1. Training Set: A training set is used to train a machine learning model. This is the core data set that the model learns from to understand patterns and relationships in the data.
  2. Validation set: The validation set helps you fine-tune your model. By evaluating model performance on it during the training phase, you can tune hyperparameters and detect overfitting.
  3. Test set: The test set provides an unbiased assessment of the model’s performance on unseen data. This is used to evaluate the model’s ability to generalize to new and unknown data.

Below is a simple implementation in Python using the scikit-learn library:

from sklearn.model_selection import train_test_split

X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
  • The first call splits the data into a training set (X_train, y_train) and a temporary set (X_temp, y_temp), with 20% of the data allocated to the temporary set.
  • The second call splits the temporary set into a validation set (X_val, y_val) and a test set (X_test, y_test), each containing 50% of the temporary set, that is, 10% of the original data.
  • The random_state=42 parameter makes the random splits reproducible across runs of the code.
  • The result is an 80/10/10 split: 80% of the data for training, 10% for validation, and 10% for the final evaluation.
Visualise Train, Test, Validation Set in Machine Learning (Image by the Author)
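
To make the proportions concrete, here is a minimal sketch that runs the same two-step split on made-up data and checks the resulting sizes. The toy arrays X and y (1,000 random samples) are assumptions for illustration only and are not part of the original example.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 1,000 samples, 3 features, binary labels
rng = np.random.default_rng(0)
X = rng.random((1000, 3))
y = rng.integers(0, 2, size=1000)

# Step 1: hold out 20% of the data as a temporary set
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# Step 2: split the temporary set in half into validation and test sets
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

split_sizes = (len(X_train), len(X_val), len(X_test))
print(split_sizes)  # (800, 100, 100), i.e. 80% / 10% / 10%
```

If you want a different ratio, such as 70/15/15, only the first test_size needs to change (to 0.3), because the second call always halves whatever was held out.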

2. Stratified Split

When a dataset is imbalanced, Stratified Split keeps the class distribution consistent across training, validation, and test sets. Here’s how to implement it in Python:

from sklearn.model_selection import train_test_split

X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

Let’s use a simple example to demonstrate stratified partitioning on a sample data set.

The following dataset is intentionally imbalanced to simulate a real-world scenario where one class is significantly more frequent than the other. In this synthetic dataset, about 90% of the instances belong to class 0 and only about 10% belong to class 1. This type of imbalance is common in domains such as fraud detection, medical diagnosis, and anomaly detection. We apply a stratified split to ensure that the class distribution remains consistent across the training, validation, and test sets.

import numpy as np
from sklearn.model_selection import train_test_split

# Creating a synthetic imbalanced dataset
np.random.seed(42)
X = np.random.rand(100, 2)  # Features
y = np.random.choice([0, 1], 100, p=[0.9, 0.1])  # Imbalanced labels with 90% in class 0 and 10% in class 1

# Stratified split
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

# Printing the class distributions in the original dataset and the split sets
print("Original Class Distribution:")
print(f"Class 0: {np.sum(y == 0)}, Class 1: {np.sum(y == 1)}\n")

print("Training Set Class Distribution:")
print(f"Class 0: {np.sum(y_train == 0)}, Class 1: {np.sum(y_train == 1)}\n")

print("Validation Set Class Distribution:")
print(f"Class 0: {np.sum(y_val == 0)}, Class 1: {np.sum(y_val == 1)}\n")

print("Test Set Class Distribution:")
print(f"Class 0: {np.sum(y_test == 0)}, Class 1: {np.sum(y_test == 1)}\n")
# OUTPUT:

Original Class Distribution:
Class 0: 90, Class 1: 10

Training Set Class Distribution:
Class 0: 63, Class 1: 7

Validation Set Class Distribution:
Class 0: 13, Class 1: 2

Test Set Class Distribution:
Class 0: 14, Class 1: 1

You might want to check out my article titled “Credit Card Fraud Detection: A Hands-On Project” where I demonstrate this method of splitting data.

3. Time Series Split

For time-series data, the Time Series Split ensures that temporal order is maintained during Data Splitting.

Time Series Split is a method designed specifically for time series data, that is, observations recorded at successive points in time, such as daily stock prices, monthly weather records, or hourly website traffic. The key challenge with such data is that order matters, because each data point often depends on the observations before it. Shuffling the rows would let the model train on the future and be tested on the past. Time Series Split avoids this by always placing the training data earlier in time than the data used for validation and testing. Here is how it works:

  1. The Time Series Split method uses a dedicated splitter, such as the TimeSeriesSplit class from scikit-learn, to create subsets that respect the chronological order of the data.
  2. Instead of shuffling the data as in a random split, it cuts the data into consecutive segments, each covering a different time period. For example, if you have daily data for a year, each segment could cover one week or one month. With scikit-learn's TimeSeriesSplit, each successive split trains on all data up to a cut-off point and tests on the segment that follows it.
  3. This is very important for time series data because it simulates real-world scenarios where predictions about the future can only be based on what happened in the past.
  4. By maintaining this temporal order, Time Series Split allows you to train the model on historical data and validate it on more recent data, mimicking the model’s performance in a real-world setting.

The following Python code shows how to perform Time Series Split using scikit-learn.

from sklearn.model_selection import TimeSeriesSplit

tscv = TimeSeriesSplit(n_splits=5)
for train_index, test_index in tscv.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
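
To see what the splitter actually produces, the following sketch prints the indices of each split for a small made-up series. The 12-row array is an assumption chosen only to keep the output readable.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# Hypothetical series: 12 observations already sorted from oldest to newest
X = np.arange(12).reshape(-1, 1)

tscv = TimeSeriesSplit(n_splits=3)
folds = []
for fold, (train_index, test_index) in enumerate(tscv.split(X), start=1):
    folds.append((train_index.tolist(), test_index.tolist()))
    print(f"Fold {fold}: train={train_index.tolist()} test={test_index.tolist()}")

# Fold 1: train=[0, 1, 2] test=[3, 4, 5]
# Fold 2: train=[0, 1, 2, 3, 4, 5] test=[6, 7, 8]
# Fold 3: train=[0, 1, 2, 3, 4, 5, 6, 7, 8] test=[9, 10, 11]
```

Note that the training window grows with each fold while the test window always lies strictly after it, so the data must be sorted chronologically before it is passed to the splitter.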


4. K-Fold Cross-Validation

K-Fold Cross-Validation splits the dataset into ‘k’ folds of approximately equal size, allowing for multiple rounds of training and validation.

K-Fold Cross-Validation is a robust technique used to evaluate the performance and generalization ability of machine learning models. It addresses the challenge of evaluating model performance with limited data by making the most of the available samples.

K-Fold Cross-Validation in Machine Learning (Image by the Author)

Math Equation (Image generated by the Author):

$$\mathrm{CV}_{(k)} = \frac{1}{k} \sum_{i=1}^{k} L_i$$

where:

  • CV_{(k)} represents the cross-validated performance estimate using k folds.
  • L_i represents the performance metric (e.g., accuracy, error rate) computed for each fold i.
  • The symbol ∑ denotes the summation of the performance metrics across all k folds.
  • The fraction (1/k) computes the average of the performance metrics across all k folds.

Here’s how it works:

How K-Fold Cross-Validation Method works (Image by the Author)
  1. The K-Fold Cross-Validation method divides the dataset into ‘k’ folds, or subsets, of approximately equal size. Here ‘k’ is the number of groups into which the data is divided.
  2. Then, ‘k’ iterations are performed, each using one fold as the validation set and the remaining k − 1 folds as the training set. This process ensures that each fold is used for validation exactly once and that the model is trained and evaluated on different subsets of the data.
  3. By performing multiple rounds of training and validation, K-Fold Cross-Validation provides a more reliable and accurate assessment of model performance compared to a single train-test split.
  4. This technique is particularly advantageous when datasets are limited, as it makes full use of the available data for both training and validation and allows for a more reliable estimate of the model’s predictive ability.
  5. After ‘k’ iterations are completed, the evaluation results are averaged to provide a more stable and representative estimate of the model’s performance, thereby reducing the impact of data splitting on the overall evaluation.

A simple implementation in Python is:

from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, val_index in kf.split(X):
    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]
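
The loop above only produces the index sets. As a rough sketch of the full procedure, the example below trains a model on each fold and averages the per-fold scores, which is the (1/k) ∑ L_i average described earlier. The synthetic dataset from make_classification and the choice of LogisticRegression are assumptions made for illustration; any estimator and metric could be substituted.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

# Hypothetical data: 200 synthetic samples with 5 features
X, y = make_classification(n_samples=200, n_features=5, random_state=42)

kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_scores = []  # L_i: accuracy on the validation fold of iteration i
for train_index, val_index in kf.split(X):
    model = LogisticRegression(max_iter=1000)  # fresh model for every fold
    model.fit(X[train_index], y[train_index])
    fold_scores.append(model.score(X[val_index], y[val_index]))

# CV estimate: the mean of the k per-fold scores
cv_estimate = float(np.mean(fold_scores))
print("Per-fold accuracy:", np.round(fold_scores, 3))
print("Cross-validated accuracy:", round(cv_estimate, 3))
```

scikit-learn also offers cross_val_score, which wraps this loop in a single call; writing it out by hand here is only meant to show where each L_i comes from.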

5. Leave-One-Out Cross-Validation

Leave-One-Out Cross-Validation (LOOCV) is an exhaustive cross-validation method, especially suitable for small datasets. At each iteration, a single data point is held out for validation and the model is trained on all the remaining points.

Math Equation (Image generated by the Author):

$$\mathrm{CV}_{(N)} = \frac{1}{N} \sum_{i=1}^{N} L_i$$

where:

  • N denotes the total number of data points in the dataset.
  • L_i represents the performance metric (e.g., accuracy, error rate) computed for each iteration i, where one data point is left out for validation.
  • The symbol ∑ denotes the summation of the performance metrics across all N iterations.
  • The fraction 1/N computes the average of the performance metrics across all iterations, providing an overall evaluation of the model’s performance using the LOOCV method.
How Leave-One-Out Cross-Validation (LOOCV) works (Image by the Author)

Here’s how you can implement it using scikit-learn:

from sklearn.model_selection import LeaveOneOut

loo = LeaveOneOut()
for train_index, val_index in loo.split(X):
    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]
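
As a sketch of how the LOOCV estimate can be computed end to end, the example below passes a LeaveOneOut splitter to cross_val_score. The 30-sample subset of the Iris dataset and the KNeighborsClassifier are assumptions chosen to keep the run fast; they are not part of the original example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical small dataset: every 5th Iris sample (30 rows, 10 per class)
X, y = load_iris(return_X_y=True)
X, y = X[::5], y[::5]

loo = LeaveOneOut()
# One score per left-out sample; with accuracy, each score is 0.0 or 1.0
scores = cross_val_score(KNeighborsClassifier(n_neighbors=3), X, y, cv=loo)

loocv_estimate = scores.mean()  # (1/N) * sum of the N per-sample scores
print("Number of iterations:", len(scores))
print("LOOCV accuracy:", round(float(loocv_estimate), 3))
```

Because each validation set holds a single point, the individual scores are very coarse; only their average over all N iterations is meaningful.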

LOOCV looks similar to K-Fold Cross-Validation, so let’s compare the two:

LOOCV vs. K-Fold Cross-Validation Comparison (Image by the Author)
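
One practical difference is cost: LOOCV fits the model once per data point, while K-Fold fits it only k times. The sketch below illustrates this on a made-up dataset of 50 rows (an assumption for illustration) by counting the splits each strategy generates.

```python
import numpy as np
from sklearn.model_selection import KFold, LeaveOneOut

# Hypothetical dataset with N = 50 samples
X = np.zeros((50, 2))

kf = KFold(n_splits=5, shuffle=True, random_state=42)
loo = LeaveOneOut()

kfold_fits = kf.get_n_splits(X)   # number of models to train with K-Fold
loo_fits = loo.get_n_splits(X)    # number of models to train with LOOCV

kfold_val_sizes = [len(val) for _, val in kf.split(X)]
loo_val_sizes = {len(val) for _, val in loo.split(X)}

print(f"K-Fold: {kfold_fits} fits, validation sizes {kfold_val_sizes}")
print(f"LOOCV:  {loo_fits} fits, validation sizes {sorted(loo_val_sizes)}")
# K-Fold: 5 fits, validation sizes [10, 10, 10, 10, 10]
# LOOCV:  50 fits, validation sizes [1]
```

In other words, LOOCV behaves like K-Fold with k equal to N, which is why it is usually reserved for small datasets.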

Conclusion

Data Splitting serves as the cornerstone for building robust and accurate models. By understanding the various Data Splitting strategies and leveraging their implementations in Python, you can ensure an efficient and effective machine learning workflow. Choose the appropriate data splitting strategy based on your specific dataset and problem, and let it pave the way for successful model development and evaluation.

#MachineLearning #DataSplitting #PythonExamples #DataScience #ModelTraining
