How to Get Optimal Models From Benchmark For Prediction on Test Data

This question applies to both mlr and mlr3, but I have only included mlr code since I have it handy. As an example, with 3-fold outer CV we get 3 sets of optimal hyperparameters (and 3 subsets of selected/filtered features). Which hyperparameter set should be used for prediction? Does mlr/mlr3 average the value of each parameter, or take the parameter set associated with the best performance measure (lowest error or highest c-index)? What is the best practice here? The same question applies to the selected/filtered features.
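
To make this concrete: after running the benchmark below, getBMRTuneResults(bmr) returns one tuning result per outer fold, and I do not know whether the final model is meant to be built roughly like this (a rough sketch only; the learner id "COX LASSO.tuned", the assumption of a single c-index measure, and the "pick the best fold" step are my guesses, and that choice is exactly what I am asking about):

# Sketch only: refit the base learner on the full training task with one
# fold's tuned values -- but should it be the best fold, or some average?
lasso_res = getBMRTuneResults(bmr)[["Survival"]][["COX LASSO.tuned"]]  # learner id is my guess
best_fold = which.max(sapply(lasso_res, function(r) r$y))              # assuming c-index (maximised)
final_lrn = setHyperPars(makeLearner("surv.glmnet"), par.vals = lasso_res[[best_fold]]$x)
final_mod = train(final_lrn, task)
predict(final_mod, newdata = test)

Or is the intended practice to simply call train(cox_lasso, task) on the wrapped learner, so the tuning is redone once on the full training data, and use that single model on the test set?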

Additional questions

  1. How best to get the optimized models from benchmark and use them for prediction on the test set, given that I set keep.extract = TRUE and models = TRUE in the benchmark call? (See the sketch after the code below.)
  2. How to get the filtered features of the COX2 model? (getBMRFilteredFeatures returns NULL.)
  3. How to get the important features from the final models from benchmark? Do I get them after I predict on the test set?

library(tidyverse)
library(tidymodels)
library(PKPDmisc)
library(mlr)
library(parallelMap)
library(survival)


# Data and Data Splitting  
data = as_tibble(lung) %>% 
       na.omit()      %>% 
       mutate(status = if_else(status==1, 0, 1))

set.seed(123)
split <- data  %>% initial_split(prop = 0.8, strata = status)    
train <- split %>% training()
test  <- split %>% testing()


# Task
task = makeSurvTask(id = "Survival", data = train, target = c("time", "status")) 

 
# Resample 
# For model assessment before external validation on test data
set.seed(123)
outer_cv  = makeResampleDesc("CV", iters = 3, stratify.cols = c("status")) %>% 
            makeResampleInstance(task) 

# For feature selection and parameter tuning
set.seed(123)
inner_cv  = makeResampleDesc("CV", iters = 3, stratify.cols = c("status"))


# Learners
cox1      = makeLearner(id = "COX1", "surv.coxph") %>%
           makeFeatSelWrapper(resampling = inner_cv, show.info = TRUE,
                              control = makeFeatSelControlSequential(method = "sffs"))
 
cox2      = makeLearner(id = "COX2", "surv.coxph") %>%
           makeFilterWrapper(fw.method = "univariate.model.score") %>% 
           makeTuneWrapper(resampling  = inner_cv, show.info = TRUE, 
                           par.set     = makeParamSet(makeIntegerParam("fw.abs", lower = 2, upper = 10)), 
                           control     = makeTuneControlGrid(resolution = 5L)) 
 
cox_lasso = makeLearner(id = "COX LASSO", "surv.glmnet") %>% 
           makeTuneWrapper(resampling = inner_cv, show.info = TRUE, 
                           par.set    =  makeParamSet(makeNumericParam("lambda",lower = -3, upper = 0.5,
                                                                       trafo = function(x) 10^x)),
                           control    = makeTuneControlGrid(resolution = 5L))
 
cox_net   = makeLearner(id = "COX NET",  "surv.glmnet") %>% 
           makeTuneWrapper(resampling = inner_cv, show.info = TRUE, 
                           par.set    =  makeParamSet(makeNumericParam("alpha", lower = 0,  upper = 1,
                                                                       trafo = function(x) round(x,2)),
                                                      makeNumericParam("lambda",lower = -3, upper = 0.5,
                                                                       trafo = function(x) 10^x)),
                           control    = makeTuneControlGrid(resolution = 5L))
 
rsf       = makeLearner(id = "RSF", "surv.randomForestSRC") %>% 
           makeTuneWrapper(resampling = inner_cv, show.info = TRUE,
                           par.set    = makeParamSet(makeDiscreteParam("ntree", values = c(500,1000,2000)),
                                                     makeIntegerParam("mtry",     lower = 2,  upper = 5),
                                                     makeIntegerParam("nodesize", lower = 2,  upper = 5),
                                                     makeIntegerParam("nsplit",   lower = 2,  upper = 5)),
                           control    = makeTuneControlGrid(resolution = 5L))


# Benchmark 
parallelStartSocket(4)
set.seed(123, "L'Ecuyer") # for reproducible results with parallel computing 
start_time <- Sys.time() 

bmr       = benchmark(learners    = list(cox1, cox2, cox_lasso, cox_net, rsf),
                     tasks       = task,
                     resamplings = outer_cv,
                     keep.extract= TRUE, 
                     models      = TRUE,
                     show.info   = TRUE)

end_time <- Sys.time()
end_time - start_time
parallelStop()

getBMRFeatSelResults(bmr)
getBMRFilteredFeatures(bmr)
getBMRTuneResults(bmr)
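
For questions 1 and 3, this is roughly what I have in mind, but I am not sure it is correct (again only a sketch; the learner ids with the wrapper suffixes and the use of getFeatureImportance() directly on a wrapped model are my assumptions):

# Question 1 sketch: benchmark stored one fitted model per outer fold,
# so which of the three should be used on the test set?
models    = getBMRModels(bmr)[["Survival"]]            # nested list: learner id -> 3 models each
rsf_mods  = models[["RSF.tuned"]]                      # learner id with wrapper suffix is my guess
test_pred = lapply(rsf_mods, predict, newdata = test)  # 3 predictions on the same test set
sapply(test_pred, performance, measures = cindex)      # compare the 3 fold models on the test data

# Question 3 sketch: feature importance from the stored models. I do not know
# whether this works on the wrapped (tuned) model directly or whether the
# inner randomForestSRC model has to be extracted first.
getFeatureImportance(rsf_mods[[1]])

Or should the importances only be taken from a single final model refitted on the whole training set?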