A More Convenient Approach to Cross-Validation
Learn to use the GridSearchCV in scikit-learn for hyperparameter tuning.
Advantages of using GridSearchCV
In the “The Bias-Variance Trade-Off” chapter, we gained a deep understanding of cross-validation by writing our own function to do it, using the KFold class to generate the training and testing indices. This was helpful to get a thorough understanding of how the process works. However, scikit-learn offers a convenient class that can do more of the heavy lifting for us: GridSearchCV. The GridSearchCV class can take as input a model that we want to find optimal hyperparameters for, such as a decision tree or a logistic regression, and a “grid” of hyperparameters that we want to perform cross-validation over. For example, in a logistic regression, we may want to get the average cross-validation score over all the folds for different values of the regularization parameter, C. With decision trees, we may want to explore different depths of trees.
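To make the idea of a grid concrete, here is a minimal sketch of searching over C for a logistic regression. The synthetic dataset, the candidate values of C, the number of folds, and the variable names are all assumptions chosen for illustration, not values from this course.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic data stands in for a real dataset (assumption for illustration).
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# The "grid" is a dictionary mapping a hyperparameter name to candidate values.
c_grid = {"C": [0.01, 0.1, 1.0, 10.0]}

c_search = GridSearchCV(
    estimator=LogisticRegression(max_iter=1000),
    param_grid=c_grid,
    cv=4,                # number of cross-validation folds (assumed)
    scoring="accuracy",
)
c_search.fit(X, y)

# cv_results_ holds one average score over the folds per candidate value of C.
mean_scores = c_search.cv_results_["mean_test_score"]
for params, score in zip(c_search.cv_results_["params"], mean_scores):
    print(f"C={params['C']}: mean CV accuracy = {score:.3f}")

print("Best C:", c_search.best_params_["C"])
```

After fitting, the cv_results_ attribute can also be loaded into a pandas DataFrame if a tabular view of the scores is easier to read.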
We can also search over multiple hyperparameters at once. For example, we might want to try different tree depths together with different values of max_features, the number of features to consider at each node split.
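Searching two hyperparameters together only requires a second key in the grid dictionary. The sketch below assumes a synthetic dataset and arbitrary candidate values for max_depth and max_features; these are illustrative choices rather than recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with 12 features (assumption for illustration).
X_tree, y_tree = make_classification(n_samples=300, n_features=12, random_state=1)

# Two hyperparameters, three candidate values each (assumed values).
tree_grid = {
    "max_depth": [2, 4, 6],
    "max_features": [3, 6, 12],
}

tree_search = GridSearchCV(
    DecisionTreeClassifier(random_state=1),
    param_grid=tree_grid,
    cv=4,
)
tree_search.fit(X_tree, y_tree)

# Every pairing of max_depth and max_features is cross-validated.
n_combinations = len(tree_search.cv_results_["params"])
print("Combinations evaluated:", n_combinations)
print("Best combination:", tree_search.best_params_)
```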
GridSearchCV does what is called an exhaustive grid search over all the possible combinations of parameters that we supply. This means that if we supplied five different values for each of two hyperparameters, the cross-validation procedure would be run 5 × 5 = 25 times, once for each combination. If you are searching many values of many hyperparameters, the number of cross-validation runs can grow very quickly. In these cases, you may wish to use RandomizedSearchCV, which searches a random subset of hyperparameter combinations from the universe of all possibilities in the grid you supply.
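The arithmetic above can be checked with scikit-learn's ParameterGrid, which enumerates every combination in a grid. The sketch below then runs RandomizedSearchCV over the same grid with a budget of eight combinations. The dataset, the candidate values, and the budget (n_iter=8) are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import ParameterGrid, RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (assumption for illustration).
X_rand, y_rand = make_classification(n_samples=300, n_features=12, random_state=2)

# Five candidate values for each of two hyperparameters (assumed values).
full_grid = {
    "max_depth": [2, 4, 6, 8, 10],
    "max_features": [2, 4, 6, 8, 10],
}

# An exhaustive search would cross-validate every combination: 5 x 5 = 25.
n_total = len(ParameterGrid(full_grid))
print("Combinations in the full grid:", n_total)

# A randomized search samples only n_iter combinations from the same grid.
random_search = RandomizedSearchCV(
    DecisionTreeClassifier(random_state=2),
    param_distributions=full_grid,
    n_iter=8,
    cv=4,
    random_state=2,
)
random_search.fit(X_rand, y_rand)

n_sampled = len(random_search.cv_results_["params"])
print("Combinations actually evaluated:", n_sampled)
print("Best sampled combination:", random_search.best_params_)
```

Note that each combination is itself fit once per fold, so with 4 folds the exhaustive search here would train 25 × 4 = 100 models, compared with 8 × 4 = 32 for the randomized search.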
GridSearchCV can speed up your work by streamlining the cross-validation process. You should be familiar with the concepts of cross-validation from the previous chapter, so we proceed directly to listing all the options available for GridSearchCV.