GridSearchCV fit call: ValueError: n_splits=5 cannot be greater than the number of members in each class

I'm trying to call GridSearchCV.fit and I get an error that I can't understand.

Input:

X_train.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 6000 entries, 9761 to 7270
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   a       6000 non-null   int64
 1   b       6000 non-null   int64
 2   c       6000 non-null   int64
 3   d       6000 non-null   int64
dtypes: int64(4)
memory usage: 234.4 KB

y_train.info()

<class 'pandas.core.series.Series'>
Int64Index: 6000 entries, 9761 to 7270
Series name: result
Non-Null Count  Dtype
--------------  -----
6000 non-null   int64
dtypes: int64(1)
memory usage: 93.8 KB

GridSearchCV call:

from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

parameters = {
    "max_depth": [1, 2, 3],
}

cv = GridSearchCV(
    DecisionTreeClassifier(), 
    parameters, 
    cv=5,
    verbose=1,
)

cv.fit(X_train, y_train)

And I receive this error:

ValueError: n_splits=5 cannot be greater than the number of members in each class.

Why? Please explain this error. Thanks.

Answer by raywib

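With cv=5 and a classifier, GridSearchCV performs a five-fold stratified split (StratifiedKFold). This means the percentage of each class is preserved in each split, so a class needs at least 5 members for one of them to be assigned to every fold. scikit-learn only raises this particular ValueError when every class in y has fewer than 5 members; if just some classes are that small, it emits a warning ("The least populated class in y has only N members, which is less than n_splits=5.") and carries on. For your 6000 rows, that means y_train holds at least 1500 distinct values, which suggests that result may be a continuous quantity rather than a class label; in that case a regressor such as DecisionTreeRegressor may be the better choice. You can take a look at the class distribution with y_train.value_counts().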
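To see when scikit-learn raises this error rather than just warning, here is a minimal sketch on two hypothetical toy targets (not the asker's data): one where every class is smaller than n_splits, and one where only a single class is. The try_split helper is made up for illustration; StratifiedKFold and value_counts are the real APIs.

```python
import warnings

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Hypothetical toy targets, not the data from the question.
# Every class has only 2 members, i.e. all classes are smaller than n_splits=5.
y_all_small = pd.Series([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
# Class 0 has 8 members, class 1 only 2: just one class is too small.
y_one_small = pd.Series([0] * 8 + [1] * 2)

# The features play no part in the class-count check, so zeros are enough.
X_dummy = np.zeros((10, 1))

# Inspect the class distribution.
print(y_all_small.value_counts())


def try_split(y, n_splits=5):
    """Hypothetical helper: return 'ok' if StratifiedKFold can split y,
    otherwise the ValueError message."""
    skf = StratifiedKFold(n_splits=n_splits)
    try:
        # split() returns a generator; the class-count check runs when it is consumed.
        list(skf.split(X_dummy, y))
        return "ok"
    except ValueError as exc:
        return str(exc)


all_small_result = try_split(y_all_small)  # every class < 5: ValueError

# Only one class < 5: scikit-learn warns but still splits.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    one_small_result = try_split(y_one_small)

print(all_small_result)
print(one_small_result, [str(w.message) for w in caught])
```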

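A quick solution would be to reduce the number of folds. This avoids the error as long as at least one class has at least as many members as there are folds, e.g.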

cv = GridSearchCV(
    DecisionTreeClassifier(), 
    parameters, 
    cv=2,
    verbose=1,
)
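A small reproduction of the fold-count fix, assuming hypothetical toy data shaped like the question's (four integer columns a to d) where every class has exactly 3 members: cv=5 hits the same ValueError, while cv=2 fits. fit_with is an illustrative helper, not part of scikit-learn.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
# Hypothetical data: 30 rows, four integer feature columns like the question's.
X = pd.DataFrame(rng.integers(0, 100, size=(30, 4)), columns=list("abcd"))
# 10 classes with 3 members each, so every class is smaller than 5.
y = pd.Series(np.repeat(np.arange(10), 3), name="result")

parameters = {"max_depth": [1, 2, 3]}


def fit_with(n_folds):
    """Hypothetical helper: run the question's grid search with n_folds folds."""
    search = GridSearchCV(DecisionTreeClassifier(random_state=0), parameters, cv=n_folds)
    search.fit(X, y)
    return search


try:
    fit_with(5)
    five_fold_error = None
except ValueError as exc:
    five_fold_error = str(exc)  # same message as in the question

search2 = fit_with(2)  # 3 members per class >= 2 folds, so this succeeds
print(five_fold_error)
print(search2.best_params_)
```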

Alternatively, you could perform a non-stratified five-fold split like so:

from sklearn.model_selection import KFold

kf5 = KFold(n_splits=5)
cv = GridSearchCV(
    DecisionTreeClassifier(), 
    parameters, 
    cv=kf5,
    verbose=1,
)
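One caveat with KFold: it ignores y entirely, so it removes the error but not its cause. If almost every label value is unique, most test labels never occur in the matching training fold and a classifier cannot predict them. This sketch assumes hypothetical data with 30 all-distinct labels, and uses shuffle=True with a fixed random_state (a choice made here for reproducibility, not part of the snippet above).

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
# Hypothetical data: 30 rows, four integer feature columns.
X = pd.DataFrame(rng.integers(0, 100, size=(30, 4)), columns=list("abcd"))
# Every label is unique: StratifiedKFold(5) would raise, KFold never looks at y.
y = pd.Series(np.arange(30), name="result")

kf5 = KFold(n_splits=5, shuffle=True, random_state=0)
search = GridSearchCV(DecisionTreeClassifier(random_state=0), {"max_depth": [1, 2, 3]}, cv=kf5)
search.fit(X, y)  # runs without the ValueError

# For each fold, count test labels that never appear in the training part.
unseen = [np.setdiff1d(y.iloc[test], y.iloc[train]).size for train, test in kf5.split(X)]
print(search.best_params_, unseen)
```

Here all 6 test labels in every fold are unseen during training, so every accuracy score is 0. If your value_counts() looks like this, the target is likely better treated as a regression problem.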

For more details on stratified splitting and other forms of cross-validation, see the scikit-learn user guide.