Use k-fold cross validation with PyTorch


I am trying to add k-fold cross validation to my script. After reading some documentation, I understand that the training loop should go inside the fold loop, and that the dataloaders should also be created inside the fold loop. In my case they are not, so if I want to use the dataloaders defined outside the fold loop and call them from inside it, how can I do that? These are the functions:

def get_train_utils(opt, model_parameters):
    # data augmentation
    # ...
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=opt.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=opt.n_threads,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               worker_init_fn=worker_init_fn)
    return (train_loader, train_sampler, train_logger, train_batch_logger,
            optimizer, scheduler)

and

def get_val_utils(opt):
    # data augmentation
    # ...
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=(opt.batch_size //
                                                         opt.n_val_samples),
                                             shuffle=False,
                                             num_workers=opt.n_threads,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             worker_init_fn=worker_init_fn,
                                             collate_fn=collate_fn)
    return val_loader, val_logger

and the training and validation loops are defined in another function:

def main_worker(index, opt):
    # other code
    ...
    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)

    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None

    prev_val_loss = None

    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, tb_writer, opt.distributed)

            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)

        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, tb_writer,
                                      opt.distributed)

        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)

1 Answer

Answered by Karl:

The dataloader is created from the dataset, and the dataset is created by the k-fold split.

What you're asking for, using dataloaders defined outside the fold loop, doesn't make sense: those dataloaders are built on a fixed dataset split, while k-fold cross validation requires a different split for each fold. If you want to do k-fold cross validation, you have to create a new dataset split (and new dataloaders) for every fold.

Pseudocode below:

dataset = ...

for k in range(n_folds):
    train_dataset, valid_dataset = split_dataset(dataset, k)  # split depends on the fold index

    train_dataloader = ... # create from train_dataset
    valid_dataloader = ... # create from valid_dataset

    train_epoch(train_dataloader, valid_dataloader, ...)
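
To make the pseudocode concrete, here is a minimal runnable sketch (not part of the original answer) that uses sklearn.model_selection.KFold to generate the fold indices and torch.utils.data.Subset to build fold-specific dataloaders. The toy TensorDataset, n_folds, and batch_size values are purely illustrative stand-ins for your own dataset and opt settings:

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import KFold

# Toy dataset standing in for the real one; any torch Dataset works here.
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

n_folds = 5
kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(dataset)))):
    # Each fold gets its own dataset views, and therefore its own dataloaders.
    train_loader = DataLoader(Subset(dataset, train_idx.tolist()),
                              batch_size=16, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx.tolist()),
                            batch_size=16, shuffle=False)

    # The model and optimizer would normally be re-initialised here, then the
    # usual epoch loop runs with these fold-specific loaders.
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} val samples")

Applied to your script, one way would be to have get_train_utils and get_val_utils accept the fold's train/validation indices (or the already-split datasets) as arguments and call them inside the fold loop, rather than reusing loaders created before the loop.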