I'm training and evaluating a logistic regression and an XGBoost classifier.
With the XGBoost classifier, training and validating on a training/validation/test split of the data shows that the model is overfitting the training data. So, I'm working with k-fold cross-validation to reduce overfitting.
To work with k-fold cross-validation, I'm splitting my data into training and test sets and performing the k-fold cross-validation on the training set. The code looks something like the following:
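from sklearn.model_selection import StratifiedKFold, cross_val_score
from xgboost import XGBClassifier

model = XGBClassifier()
kfold = StratifiedKFold(n_splits = 10)
# One score per fold
results = cross_val_score(model, x_train, y_train, cv = kfold)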
The code works. Now, I've read several forums and blogs on how to make predictions after a k-fold cross-validation, but after these readings, I'm still not sure about the proper way of doing the predictions.
It would seem that using the cross_val_predict() function from sklearn.model_selection on the test set is OK. The code would look something like the following:
y_pred = cross_val_predict(model, x_test, y_test, cv = kfold)
The code works, but I'm not sure it makes sense: I've seen more complicated ways of doing this, and in those it isn't clear whether the training or the test set should be used for the predictions.
And if this makes sense, computing the accuracy score and the confusion matrix would be as simple as running something like the following:
from sklearn import metrics

accuracy = metrics.accuracy_score(y_test, y_pred)
cm = metrics.confusion_matrix(y_test, y_pred)
These two would help compare the logistic regression and the XGBoost classifier. Does this way of making predictions and evaluating models make sense?
Any help is appreciated! Thanks!
I want to answer this question I posted myself by summarizing things I have read and tried.
First, I want to clarify that the idea behind splitting my data into training/test sets and performing the k-fold cross-validation on the training set is to reserve the test set for estimating the generalization error, in much the same way as the test set is used when data is split into training/validation/test sets. For the sake of clarity, let me split the discussion into 2 sections.
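The workflow described above can be sketched roughly as follows. This is only an illustration under assumptions that are not part of my original setup: the data is synthetic (make_classification), LogisticRegression stands in for the classifier (XGBClassifier() could be swapped in), and refitting on the whole training set before scoring the test set is one common choice, not the only one.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split

# Synthetic data, purely for illustration (assumption).
x, y = make_classification(n_samples=500, n_features=10, random_state=0)

# Hold out a test set that the cross-validation never sees.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y, random_state=0
)

model = LogisticRegression(max_iter=1000)  # stand-in classifier (assumption)
kfold = StratifiedKFold(n_splits=10)

# Cross-validation on the training set only: one score per fold.
cv_scores = cross_val_score(model, x_train, y_train, cv=kfold)
print(cv_scores.mean(), cv_scores.std())

# cross_val_score fits clones of `model`, so `model` itself is still unfitted.
# Fit it on the whole training set, then score it once on the held-out test set.
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
test_accuracy = accuracy_score(y_test, y_pred)
test_cm = confusion_matrix(y_test, y_pred)
print(test_accuracy)
print(test_cm)
```

Here the cross-validation scores describe how the model behaves on the training data's internal folds, while test_accuracy and test_cm come from data that played no part in the cross-validation.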
Section 1
Now, after reading more, it's clearer to me that cross_val_predict() returns the predictions that were obtained during the cross-validation when the elements were in a test set (see section 3.1.1.2 in this scikit-learn cross-validation doc). This test set refers to one of the test sets the cross-validation procedure internally creates (cross-validation creates a test set in each fold). Thus, cross_val_predict(model, x_train, y_train, cv = kfold) returns the predictions from the cross-validation internal test sets. It then seems safe to obtain the accuracy and confusion matrix from y_train and these predictions.
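A minimal sketch of the Section 1 idea, with the same caveats as any toy example: the data is synthetic, LogisticRegression is a stand-in for the classifier, and the variable names y_train_pred, cv_accuracy and cv_cm are my own choices for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict, train_test_split

# Synthetic data, purely for illustration (assumption).
x, y = make_classification(n_samples=500, n_features=10, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y, random_state=0
)

model = LogisticRegression(max_iter=1000)  # stand-in classifier (assumption)
kfold = StratifiedKFold(n_splits=10)

# Each training sample is predicted exactly once, by a model that was
# fitted on the folds that did not contain that sample.
y_train_pred = cross_val_predict(model, x_train, y_train, cv=kfold)

# Metrics computed from the out-of-fold predictions.
cv_accuracy = accuracy_score(y_train, y_train_pred)
cv_cm = confusion_matrix(y_train, y_train_pred)
print(cv_accuracy)
print(cv_cm)
```

Note that x_test and y_test are never passed to cross_val_predict() here; they stay untouched for a later generalization estimate.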
While cross_val_predict(model, x_test, y_test, cv = kfold) runs, it seems doing this doesn't make much sense, since it cross-validates on the test set alone and never uses the training data.

Section 2
From some blogs that talk about creating a confusion matrix after a cross-validation procedure (see here and here), I borrowed code that, for each fold of the cross-validation, extracts the labels and predictions from the internal test set. These labels and predictions are later used to compute the confusion matrix. Assuming I store the labels and predictions in variables called actual_classes and predicted_classes, respectively, I then compute the confusion matrix from them. The results are exactly the same as the ones from Section 1's equivalent code. This reinforces that cross_val_predict(model, x_train, y_train, cv = kfold) works fine.

Thus:

Does it make sense to use cross_val_predict() to make predictions with unseen data in k-fold cross-validation? I would say no, it doesn't, since cross_val_predict() makes predictions with the internal test sets from the cross-validation procedure. It seems that to make predictions with unseen data and compute a generalization error we would need a way to extract one of the models from the cross-validation procedure (e.g., see this question).

Does it make sense to use cross_val_predict() to compare models? I would say yes, it does, as long as the function is called as shown in Section 1. The accuracy and confusion matrix could be used to make comparisons against other models.

Any comment is appreciated! Thanks!
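For reference, the per-fold extraction described in Section 2 can be sketched as below. This is not the exact code from the blogs I mentioned, just a minimal reconstruction of the idea; the synthetic data, the LogisticRegression stand-in, and the names fold_model, cm_manual and cm_cvp are assumptions made for the example.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic training data, purely for illustration (assumption).
x_train, y_train = make_classification(n_samples=400, n_features=10, random_state=0)

model = LogisticRegression(max_iter=1000)  # stand-in classifier (assumption)
kfold = StratifiedKFold(n_splits=10)

actual_classes = np.empty(0, dtype=int)
predicted_classes = np.empty(0, dtype=int)

for train_idx, test_idx in kfold.split(x_train, y_train):
    # Fresh, unfitted copy of the model for every fold.
    fold_model = clone(model)
    fold_model.fit(x_train[train_idx], y_train[train_idx])
    # Collect labels and predictions from this fold's internal test set.
    actual_classes = np.append(actual_classes, y_train[test_idx])
    predicted_classes = np.append(predicted_classes, fold_model.predict(x_train[test_idx]))

# Confusion matrix from the manually collected folds ...
cm_manual = confusion_matrix(actual_classes, predicted_classes)
# ... and from cross_val_predict() with the same splitter.
cm_cvp = confusion_matrix(y_train, cross_val_predict(model, x_train, y_train, cv=kfold))
print(np.array_equal(cm_manual, cm_cvp))
```

The manual loop collects samples in fold order rather than in the original order, but a confusion matrix does not depend on sample order, so the two matrices can be compared directly.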