Cross-Validation

Learn about K-Fold cross-validation and why it's used.

Chapter Goals:

  • Learn about the purpose of cross-validation
  • Implement a function that applies the K-Fold cross-validation algorithm to a model

A. Additional evaluation datasets

Sometimes it's not enough to have just a single testing set for model evaluation. Having additional sets of data for evaluation gives us a more accurate measurement of how well the model performs on the original dataset.

If the original dataset is big enough, we can actually split it into three subsets: training, validation, and testing. The validation set is about the same size as the testing set, and it is used for evaluating the model as we train and tune it. The testing set is then reserved for the final evaluation once the model is done training and tuning.
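As a quick sketch of how such a three-way split might be produced, we can call scikit-learn's train_test_split twice. The 60/20/20 proportions below are an arbitrary choice for illustration, and data and labels are assumed to be predefined NumPy arrays (as in the later code examples).

from sklearn.model_selection import train_test_split
# Set aside 20% of the data for the testing set...
train_val_data, test_data, train_val_labels, test_labels = train_test_split(
    data, labels, test_size=0.2)
# ...then carve an equally sized validation set out of the remainder
# (0.25 * 0.8 = 0.2 of the original dataset)
train_data, val_data, train_labels, val_labels = train_test_split(
    train_val_data, train_val_labels, test_size=0.25)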

However, partitioning the original dataset into three distinct sets will cut into the size of the training set. This can reduce the performance of the model if our original dataset is not large enough. A solution to this problem is cross-validation (CV).

Cross-validation creates synthetic validation sets by partitioning the training set into multiple smaller subsets. One of the most common cross-validation algorithms, K-Fold CV, partitions the training set into k approximately equal-sized subsets, referred to as folds. The algorithm runs for k "rounds". Each round chooses one of the k folds as the validation set (a different fold is chosen each round), while the remaining k - 1 folds are aggregated into the round's training set and used to train the model.

In each round of the K-Fold algorithm, the model is trained on that round's training set (the combined training folds) and then evaluated on the single held-out validation fold. The evaluation metric depends on the model. For classification models, it is usually classification accuracy on the validation set. For regression models, it can be the model's mean squared error, mean absolute error, or R² value on the validation set.
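To make the procedure concrete, here is one way to implement a K-Fold scoring function ourselves, in line with the chapter goals. This is only a sketch, not scikit-learn's own implementation: the function name kfold_cv_score is our own, and data and labels are assumed to be NumPy arrays.

import numpy as np
from sklearn.model_selection import KFold

def kfold_cv_score(model, data, labels, k=5):
    # Generate k train/validation index splits
    kf = KFold(n_splits=k)
    scores = []
    for train_idx, val_idx in kf.split(data):
        # Train on the k - 1 aggregated folds for this round
        model.fit(data[train_idx], labels[train_idx])
        # Evaluate on the held-out fold; score returns classification
        # accuracy for classifiers and the R^2 value for regressors
        scores.append(model.score(data[val_idx], labels[val_idx]))
    return np.array(scores)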

B. Scored cross-validation

In scikit-learn, we can easily implement K-Fold cross-validation with the cross_val_score function (also part of the model_selection module). The function returns an array containing the evaluation score for each round.

The code below demonstrates K-Fold CV with 3 folds for classification. The evaluation metric is classification accuracy.

from sklearn import linear_model
from sklearn.model_selection import cross_val_score
# We can skip the max_iter argument here, but that produces a
# ConvergenceWarning, so we explicitly pass a larger value to
# avoid the warning.
clf = linear_model.LogisticRegression(max_iter=3000)
# The course environment predefines data and labels; we load an example
# dataset here so the snippet runs on its own
from sklearn.datasets import load_breast_cancer
data, labels = load_breast_cancer(return_X_y=True)
cv_score = cross_val_score(clf, data, labels, cv=3) # k = 3
print('{}\n'.format(repr(cv_score)))

The code below demonstrates K-Fold CV with 4 folds for regression. The evaluation metric is the R² value.

from sklearn import linear_model
from sklearn.model_selection import cross_val_score
reg = linear_model.LinearRegression()
# Again, data and labels are predefined in the course environment; we load
# an example dataset here so the snippet runs on its own
from sklearn.datasets import load_diabetes
data, labels = load_diabetes(return_X_y=True)
cv_score = cross_val_score(reg, data, labels, cv=4) # k = 4
print('{}\n'.format(repr(cv_score)))

Note that we don't call fit on the model prior to using cross_val_score. This is because cross_val_score calls fit internally to train the model in each round.

For classification models, the cross_val_score function will apply a special form of the K-Fold algorithm called stratified K-Fold. This just means that each fold will contain approximately the same class distribution as the original dataset. For example, if the original dataset contained 60% class 0 data observations and 40% class 1, each fold of the stratified K-Fold algorithm will have about the same 60-40 split between class 0 and class 1 data observations.
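If we want the stratification to be explicit, or want to configure it (e.g. with shuffling), we can pass a StratifiedKFold object as the cv argument instead of an integer. A minimal sketch, reusing data and labels from the classification example above:

from sklearn import linear_model
from sklearn.model_selection import StratifiedKFold, cross_val_score

clf = linear_model.LogisticRegression(max_iter=3000)
# An explicit stratified splitter; for a classifier, an integer cv
# value would default to the same stratified behavior
skf = StratifiedKFold(n_splits=3)
cv_score = cross_val_score(clf, data, labels, cv=skf)
print('{}\n'.format(repr(cv_score)))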

While cross-validation gives us a better measurement of the model's fit on the original dataset, it can be very time-consuming on large datasets. For large enough datasets, it is better to split the data into training, validation, and testing sets, and use the validation set for evaluating the model before it is finalized.
