Decision tree with `rpart` and `caret` for a target variable split into quartiles

The following code

library(rpart)
library(caret)
youdenSumary <- function(data, lev = NULL, model = NULL){
  if (length(lev) > 2) {
    stop(paste("Your outcome has", length(lev), "levels. The youdenSumary() function isn't appropriate."))
  }
  if (!all(levels(data[, "pred"]) == lev)) {
    stop("levels of observed and predicted data do not match")
  }
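  # Sensitivity treats lev[1] as the positive class; specificity treats lev[2] as the negative class
  Sens <- caret::sensitivity(data[, "pred"], data[, "obs"], lev[1])
  Spec <- caret::specificity(data[, "pred"], data[, "obs"], lev[2])
  # This is balanced accuracy; it ranks models exactly as Youden's J (Sens + Spec - 1) does
  j <- (Sens + Spec)/2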
  out <- c(j, Spec, Sens)
  names(out) <- c("j", "Spec", "Sens")
  out
}



trctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 20,
                       search = "grid",summaryFunction = youdenSumary)

classifier = train(x = training_set[, names(training_set) != "Target"],
                   y = training_set$Target,
                   method = 'rpart',
                   parms = list(split = "gini"),trControl=trctrl,
                   tuneLength = 10,metric = "j")
classifier
# bestTune is a one-row data frame; extract the numeric cp value from it
complexity_parameter = classifier$bestTune$cp


folds = createFolds(dataset$Target, k = 10)
cv = lapply(folds, function(x) {
  training_fold = dataset[-x, ]
  test_fold = dataset[x, ]
  classifier = rpart(formula = Target ~ .,
                     data = training_fold,control = rpart.control(cp = complexity_parameter))
  y_pred = predict(classifier, newdata = test_fold[!(names(test_fold)%in%"Target")], type = 'class')
  # Compare the predicted values with the target variable (rows: predicted, columns: observed)
  cm = table(y_pred, test_fold$Target)
  accuracy = (cm[1,1] + cm[2,2]) / (cm[1,1] + cm[2,2] + cm[1,2] + cm[2,1])
  # With predictions in rows, each column sum is the number of observed cases of that class
  sensitivity = cm[1,1] / (cm[1,1] + cm[2,1])
  specificity = cm[2,2] / (cm[1,2] + cm[2,2])
  df = data.frame(accuracy = accuracy, sensitivity=sensitivity,
                  specificity=specificity)
  return(df)
})
accuracy = Reduce("+", lapply(cv, "[[", 1))/10
sensitivity = Reduce("+", lapply(cv, "[[", 2))/10
specificity = Reduce("+", lapply(cv, "[[", 3))/10
balanced_accuracy=(sensitivity+specificity)/2

performs a grid search to find the best value of the complexity parameter `cp`. With method = "repeatedcv", number = 10 and repeats = 3, for example, the resampling scheme consists of three separate 10-fold cross-validations (the call above uses repeats = 20, i.e. twenty of them).
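To make the resampling scheme concrete, here is a minimal base-R sketch of how three repeats of 10-fold cross-validation yield 30 training sets. It is only an illustration: the names `n`, `k`, `repeats` and `resamples` are made up for this example, and the folds here are drawn purely at random, whereas caret (as far as I know) also balances the outcome classes across folds.

```r
set.seed(1)
n <- 100        # hypothetical number of rows in the training set
k <- 10         # folds per cross-validation
repeats <- 3    # number of separate cross-validations

resamples <- list()
for (r in seq_len(repeats)) {
  # Draw a fresh random partition of the rows into k folds for every repeat
  fold_id <- sample(rep(seq_len(k), length.out = n))
  for (f in seq_len(k)) {
    # Keep the training rows: everything outside the held-out fold f
    resamples[[sprintf("Fold%02d.Rep%d", f, r)]] <- which(fold_id != f)
  }
}

length(resamples)            # number of model fits per candidate cp value
range(lengths(resamples))    # size of each training set
```

Each candidate `cp` is therefore fitted and scored `k * repeats` times, and the summary metric is averaged over those resamples.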

Cross-validation is then applied a second time, with the selected parameter, to obtain accuracy, sensitivity and so on.

This code is designed for a binary target variable. How can I adapt this code for a non-binary target variable, for example a target variable divided into quartiles (i.e. 1, 2, 3, 4)?

Answer by Ash Sandhu

I think you might need to rewrite a lot of this.

The summary function only works because caret's `sensitivity` and `specificity` functions are comparing a binary outcome. It won't work in the multiclass case. I would check out this great package, https://github.com/WandeRum/multiROC, which goes into more detail on this.
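One possible direction, as a rough sketch rather than a definitive solution: compute sensitivity and specificity for each class in a one-vs-rest fashion and average them across classes (macro-averaging). The function name `macroYoudenSummary` is hypothetical and not part of caret; it only keeps the `(data, lev, model)` signature of the summary function in the question and assumes `data` has the factor columns `obs` and `pred`.

```r
# Macro-averaged one-vs-rest sensitivity, specificity and Youden's J.
# Hypothetical helper, written in base R only.
macroYoudenSummary <- function(data, lev = NULL, model = NULL) {
  if (is.null(lev)) lev <- levels(data[, "obs"])
  pred <- factor(data[, "pred"], levels = lev)
  obs  <- factor(data[, "obs"],  levels = lev)
  cm <- table(pred, obs)          # rows: predicted, columns: observed
  tp <- diag(cm)                  # class k predicted and observed
  fn <- colSums(cm) - tp          # class k observed, another class predicted
  fp <- rowSums(cm) - tp          # class k predicted, another class observed
  tn <- sum(cm) - tp - fn - fp    # neither predicted nor observed as class k
  sens <- tp / (tp + fn)
  spec <- tn / (tn + fp)
  # Average the per-class values; classes absent from a fold give NaN and are skipped
  c(j    = mean(sens + spec - 1, na.rm = TRUE),
    Sens = mean(sens, na.rm = TRUE),
    Spec = mean(spec, na.rm = TRUE))
}

# Toy check with a four-level (quartile) outcome
obs  <- factor(c(1, 1, 2, 2, 3, 3, 4, 4), levels = 1:4)
pred <- factor(c(1, 2, 2, 2, 3, 4, 4, 4), levels = 1:4)
res <- macroYoudenSummary(data.frame(obs = obs, pred = pred), lev = levels(obs))
print(res)
```

Under these assumptions it could be passed as `summaryFunction = macroYoudenSummary` in `trainControl()` with `metric = "j"` in `train()`, provided `Target` is a factor with four levels. The manual fold loop would need the same per-class treatment, since indexing `cm[1,1]` and `cm[2,2]` only covers a 2x2 table. caret's built-in `multiClassSummary` may also be worth a look before writing your own.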