Angelica Lo Duca

Summary

This article is a tutorial on how to update a pre-trained machine learning model with new data using Scikit-learn's warm_start parameter and partial_fit() method.

Abstract

The article offers a step-by-step guide for practitioners looking to incrementally train machine learning models using Scikit-learn. It addresses the challenge of retraining models when new data becomes available, emphasizing the inefficiency of retraining from scratch, especially for large datasets or models that require extensive computational resources. The author illustrates two strategies: using warm_start=True to add new estimators to an existing ensemble model, such as a Random Forest Classifier, and employing the partial_fit() method available in certain models like the SGDClassifier. The tutorial uses the Iris dataset to demonstrate how these methods can improve model performance without discarding previously learned information. The article concludes by encouraging readers to explore the Scikit-learn documentation for models that support incremental learning and to check out the author's GitHub repository for the code examples used in the tutorial.

Opinions

  • The author suggests that using warm_start=True and partial_fit() can be more efficient than retraining a model from scratch when new data is available.
  • It is noted that warm_start=True should not be used in cases where there is concept drift in the data.
  • The article implies that incremental learning can lead to better model performance, as demonstrated by the improved score when using warm_start=True with additional data.
  • The author expresses that not all models in Scikit-learn support incremental learning, highlighting the importance of checking the documentation.
  • There is an encouragement to engage with the author's other content, suggesting that the reader may find additional value in the author's broader body of work.


How to Add New Data to a Pretrained Model in Scikit-learn

A step-by-step tutorial on how to use warm_start=True and partial_fit() in scikit-learn

Photo by h heyerlein on Unsplash

When you build a machine learning model from scratch, you usually split your dataset into a training set and a test set, and then train your model on the training set. Then you evaluate the model on the test set and, if the performance is decent, you can use the model for prediction.

But what if new data becomes available at some point?

In other words, how do you train a model that has already been trained? Or, put differently, how do you add new data to an already trained model?

In this article, I try to answer this non-trivial question using the scikit-learn library. You can check this interesting article by Vidhi Chugh to understand when you need to retrain your model.

One possible (trivial) solution is to retrain the model from scratch on both the old and the new data. However, this solution does not scale when training takes a long time, because every update repeats the whole training on an ever-growing dataset.
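For reference, the trivial approach can be sketched as follows, assuming a random forest and a 40/60 split of the iris data into "old" and "new" samples. The helper retrain_from_scratch is hypothetical, not part of scikit-learn: it simply clones the estimator and refits it on everything seen so far, which is exactly why its cost grows with each update.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def retrain_from_scratch(model, X_old, y_old, X_new, y_new):
    """Hypothetical helper: fit a fresh copy of `model` on old + new data."""
    X_all = np.concatenate([X_old, X_new])
    y_all = np.concatenate([y_old, y_new])
    fresh = clone(model)  # same hyperparameters, no learned state
    return fresh.fit(X_all, y_all)


X, y = load_iris(return_X_y=True)
# "Old" data (40%) and "new" data (60%), an assumed split for illustration.
X_old, X_new, y_old, y_new = train_test_split(X, y, test_size=0.60, random_state=42)

base = RandomForestClassifier(n_estimators=10, random_state=0)
model = retrain_from_scratch(base, X_old, y_old, X_new, y_new)
```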

A better solution is to add the new samples to the already trained model. Scikit-learn allows you to do this in some cases, provided you take a few precautions.

Scikit-learn proposes two strategies:

  • partial fit
  • warm start

To illustrate how to add new data to a pre-trained model in Scikit-learn, I will use a practical example, using the well-known iris dataset, provided by the Scikit-learn library.

warm start

warm_start is a parameter provided by some scikit-learn models. If it is set to True, a subsequent call to fit reuses the attributes of the already fitted model to initialize the new fit, instead of starting from scratch.

For example, you can set warm_start = True in a Random Forest Classifier and fit the model as usual. If you then increase n_estimators and call the fit method again on new data, the additional trees are trained on the new data and added to the existing ones. If you do not increase n_estimators, no new trees are fitted and scikit-learn issues a warning. This means that warm_start = True does not change the existing trees.
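A minimal sketch of this behavior, assuming the iris data split into two stratified halves (so that both halves contain all three classes). It checks that refitting without raising n_estimators adds no trees, and that after raising it the original tree objects are still the first entries of estimators_.

```python
import warnings

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
# Two stratified halves, so both contain all three classes.
X_a, X_b, y_a, y_b = train_test_split(X, y, test_size=0.5, stratify=y, random_state=0)

model = RandomForestClassifier(n_estimators=3, warm_start=True, random_state=0)
model.fit(X_a, y_a)
old_trees = list(model.estimators_)  # references to the first three trees

# Calling fit again without raising n_estimators fits no new trees (and warns).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.fit(X_b, y_b)
print(len(model.estimators_), [str(w.message) for w in caught])

# Raising n_estimators makes the next fit train two extra trees on X_b.
model.n_estimators += 2
model.fit(X_b, y_b)
print(len(model.estimators_))  # 5
print(all(a is b for a, b in zip(old_trees, model.estimators_)))  # True
```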

warm_start = True should not be used for incremental learning on new data that could be affected by concept drift. Concept drift is a change in the data that happens when the underlying relationship between the input variables and the output variable changes.
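To make this caveat concrete, here is a small sketch on synthetic data. The make_batch helper and the one-dimensional data are assumptions for illustration, not part of the original tutorial: the label is x > 0 before the drift and the opposite afterwards. The five trees learned before the drift should outvote the single tree added afterwards, while a forest retrained on post-drift data only follows the new relationship.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)


def make_batch(n, flipped):
    """Hypothetical 1-D data: label is x > 0, or the opposite after the drift."""
    X = rng.uniform(-1, 1, size=(n, 1))
    y = (X[:, 0] > 0).astype(int)
    return X, (1 - y if flipped else y)


X_old, y_old = make_batch(200, flipped=False)
X_new, y_new = make_batch(200, flipped=True)
X_test, y_test = make_batch(200, flipped=True)  # evaluate on the new relationship

# Warm start: 5 trees learned before the drift + 1 tree learned after it.
warm = RandomForestClassifier(n_estimators=5, warm_start=True, random_state=0)
warm.fit(X_old, y_old)
warm.n_estimators += 1
warm.fit(X_new, y_new)

# Retraining from scratch on post-drift data only.
fresh = RandomForestClassifier(n_estimators=6, random_state=0).fit(X_new, y_new)

print(warm.score(X_test, y_test), fresh.score(X_test, y_test))
```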

To understand how warm_start = True works, consider an example. The idea is to show that warm_start = True can improve the performance of an algorithm when I add new data that has the same distribution as the original data and maintains the same relationship with the output variable.

Firstly, I load the iris dataset, provided by the Scikit-learn library:

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target

Then, I split the dataset into three parts:

  • X_train, y_train: training set, 80% of 40% of the data (48 samples)
  • X_test, y_test: test set, 20% of 40% of the data (12 samples)
  • X2, y2: new samples, 60% of the data (90 samples)

from sklearn.model_selection import train_test_split
X1, X2, y1, y2 = train_test_split(X, y, test_size=0.60, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)
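As a quick sanity check of the sizes listed above, the split can be reproduced and measured. Since train_test_split rounds the test size up, 60% of 150 gives 90 new samples, and 20% of the remaining 60 gives 12 test samples. The class counts line is an addition for inspection only; no stratification is used, so the classes in the small training set need not be balanced.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X, y = iris.data, iris.target

X1, X2, y1, y2 = train_test_split(X, y, test_size=0.60, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)

print(len(X_train), len(X_test), len(X2))  # 48 12 90

# Class counts in the small training set (the split is not stratified).
counts = np.bincount(y_train, minlength=3)
print(counts)
```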

I will use X2 and y2 to retrain the model.

Note that the training set is very small (48 samples).

I train a random forest with a single tree (n_estimators=1) and warm_start = False:

from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=False, n_estimators=1)
model.fit(X_train, y_train)

I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.75

Now, I fit the model on new data:

model.fit(X2, y2)

Since warm_start = False, this second call to fit discards what the model learned from X_train and trains it from scratch on X2 only. Then, I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.8333333333333334
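To check that the second call to fit really discards the first model, this sketch compares the refitted forest with a forest that has only ever seen X2, using the same hyperparameters and random_state. Because random_state is a fixed integer, both fits draw the same randomness, so identical predictions indicate that nothing learned from X_train survived.

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X, y = iris.data, iris.target
X1, X2, y1, y2 = train_test_split(X, y, test_size=0.60, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)

params = dict(max_depth=2, random_state=0, warm_start=False, n_estimators=1)

# Same sequence as in the tutorial: fit on X_train, then fit again on X2.
refit = RandomForestClassifier(**params).fit(X_train, y_train)
refit.fit(X2, y2)

# A model that has only ever seen X2.
only_new = RandomForestClassifier(**params).fit(X2, y2)

# Compare predictions on every iris sample.
same = np.array_equal(refit.predict(X), only_new.predict(X))
print(same)  # True
```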

Now I build a new model with warm_start = True to see whether the score increases.

model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=True, n_estimators=1)
model.fit(X_train, y_train)
model.score(X_test, y_test)

which gives the following output:

0.75

Now, I increase n_estimators by one, so that fit adds a new tree trained on the new data, and I calculate the score:

model.n_estimators += 1  # add one tree; the existing tree is kept unchanged
model.fit(X2, y2)
model.score(X_test, y_test)

which gives the following output:

0.9166666666666666

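The incremental learning has improved the score! Keep in mind, however, that the test set contains only 12 samples (each one is worth about 8 percentage points) and that the warm-started forest now has two trees instead of one, so this comparison is illustrative rather than conclusive.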
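When new data keeps arriving, the same idea can be wrapped in a loop. In this sketch, add_trees is a hypothetical helper, and X2 is assumed to arrive in three batches. The forest recomputes classes_ on every call to fit, so the batches are built with StratifiedKFold to make sure each one contains every class.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split


def add_trees(model, X_batch, y_batch, n_new=1):
    """Hypothetical helper: train `n_new` extra trees on a new batch.

    Assumes `model` is a fitted forest created with warm_start=True.
    """
    model.n_estimators += n_new
    return model.fit(X_batch, y_batch)


iris = datasets.load_iris()
X1, X2, y1, y2 = train_test_split(iris.data, iris.target, test_size=0.60, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)

model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=True, n_estimators=1)
model.fit(X_train, y_train)

# Pretend X2 arrives in three stratified batches (each contains every class).
splitter = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
for _, batch_idx in splitter.split(X2, y2):
    add_trees(model, X2[batch_idx], y2[batch_idx])
    print(len(model.estimators_), model.score(X_test, y_test))
```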

partial fit

The second strategy provided by Scikit-learn to add new data to a pre-trained model is the partial_fit() method. Not all models provide this method.

While warm_start = True in a random forest does not change the trees the model has already learned, partial_fit() updates the parameters already learned by the model (for example, the coefficients of a linear model), because it continues learning from the new data.
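A minimal sketch of this difference, assuming an SGDClassifier with a fixed random_state and the same 40/60 split: it saves a copy of coef_ after the first call to partial_fit and checks that a second call on the new samples changes those coefficients, while their shape stays the same.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X1, X2, y1, y2 = train_test_split(iris.data, iris.target, test_size=0.60, random_state=42)

model = SGDClassifier(random_state=0)
model.partial_fit(X1, y1, classes=np.unique(iris.target))
coef_before = model.coef_.copy()  # copy, so later updates cannot alter the saved values

model.partial_fit(X2, y2)  # one more pass, over the new samples only

print(model.coef_.shape)                      # (3, 4): one row per class
print(np.allclose(coef_before, model.coef_))  # False: the weights were updated
```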

I consider again the iris dataset.

Now I use an SGDClassifier:

from sklearn.linear_model import SGDClassifier
import numpy as np
model = SGDClassifier()  # no random_state: the scores below can differ between runs
model.partial_fit(X_train, y_train, classes=np.unique(y))  # first call: list all classes

The first time I call partial_fit(), I must also pass the full list of classes. In this example, I assume that I know all the classes contained in y, even though the first batch may not contain enough samples to represent all of them.
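This sketch shows why the classes argument matters, assuming a first batch that contains only classes 0 and 1. Omitting classes on the first call raises a ValueError. Passing the full list makes the model aware of class 2 from the start, even though no sample of that class has been seen yet.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import SGDClassifier

iris = datasets.load_iris()
X, y = iris.data, iris.target
all_classes = np.unique(y)  # array([0, 1, 2])

# A first batch that happens to contain only classes 0 and 1.
first = y < 2
X_first, y_first = X[first], y[first]

raised = False
probe = SGDClassifier(random_state=0)
try:
    probe.partial_fit(X_first, y_first)  # no `classes` on the first call
except ValueError as err:
    raised = True
    print("ValueError:", err)

model = SGDClassifier(random_state=0)
model.partial_fit(X_first, y_first, classes=all_classes)
print(model.classes_)  # [0 1 2]: class 2 is known even without samples
```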

I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.4166666666666667

Now, I add new samples to the model:

model.partial_fit(X2, y2)

and I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.8333333333333334

Adding new data has improved the performance of the algorithm!
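In practice, new data often arrives in small chunks rather than all at once. Here is a sketch of that pattern, assuming X2 is fed in mini-batches of 30 samples and using a fixed random_state so the run is repeatable. Because the classes were fixed on the first call, later batches do not need to contain every class.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X1, X2, y1, y2 = train_test_split(iris.data, iris.target, test_size=0.60, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)

model = SGDClassifier(random_state=0)  # fixed seed for repeatable shuffling
model.partial_fit(X_train, y_train, classes=np.unique(iris.target))

scores = []
# Feed the new data in mini-batches of 30, as if it arrived over time.
for start in range(0, len(X2), 30):
    batch = slice(start, start + 30)
    model.partial_fit(X2[batch], y2[batch])  # classes only needed on the first call
    scores.append(model.score(X_test, y_test))

print(scores)
```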

Summary

Congratulations! You have just learned how to add new data to a pre-trained model in Scikit-learn, either by setting the warm_start parameter to True or by calling the partial_fit() method. However, not all models in the Scikit-learn library support adding new data to a pre-trained model, so my suggestion is to check the documentation!
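Besides reading the documentation, you can do a quick programmatic check. This sketch uses a hypothetical helper, incremental_support, that looks for a partial_fit method on the class and for a warm_start entry in the default parameters. The selection of estimators is an arbitrary assumption.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier


def incremental_support(estimator_cls):
    """Hypothetical helper: return (has partial_fit, has warm_start)."""
    has_partial_fit = hasattr(estimator_cls, "partial_fit")
    has_warm_start = "warm_start" in estimator_cls().get_params()
    return has_partial_fit, has_warm_start


support = {
    cls.__name__: incremental_support(cls)
    for cls in (RandomForestClassifier, SGDClassifier, LogisticRegression,
                MultinomialNB, KNeighborsClassifier)
}
for name, (pf, ws) in support.items():
    print(f"{name:24} partial_fit={pf!s:5} warm_start={ws}")
```

Note that warm_start does not mean the same thing everywhere. In a random forest it adds new trees, while in LogisticRegression it reuses the previous solution as the starting point of the next fit, so the documentation of each estimator is still worth reading.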

You can download the code used in this tutorial from my GitHub repository.

If you have read this far, thank you! That already means a lot to me.
