Transfer learning in PyTorch: overfitting only during the "all" stage


I am working on fine-tuning a CNN-based multi-label classification model for a research project. The model was built and tweaked by several people before me and showed strong results in its last published iteration. All of our models were built using PyTorch's transfer-learning tools.

However, now that I have begun fine-tuning it for a slightly different use case, I am seeing a big drop in performance. After sweeping the project for errors, I found one that caused the loss per sample to accumulate improperly, which had disguised overfitting in my loss curves. But the model only begins overfitting when it enters the "all" stage, when all the parameters are unfrozen. [Plot: loss-per-sample curves for my most recent resnet run, using the basic parameters described in prior work. Blue and grey lines are "top"-stage validation and training loss; pink and orange are "all"-stage training and validation loss.]
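
For clarity, the corrected accumulation follows the standard pattern for a mean-reduced loss (a minimal sketch assuming model, criterion, and dataloader are already set up as in the code below; names are illustrative, not my exact fix):

running_loss = 0.0
n_samples = 0
for batch in dataloader:
    inputs, targets = batch['X'].cuda(), batch['Y'].cuda()
    loss = criterion(model(inputs), targets)  #criterion reduces to a batch mean by default
    running_loss += loss.item() * inputs.size(0)  #re-weight by batch size before summing
    n_samples += inputs.size(0)
loss_per_sample = running_loss / n_samples  #correct even when the last batch is smaller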

I'm at a loss as to why this is happening. I've double-checked the metrics trackers, the training loop, and anywhere else I could think of an issue arising, but no luck. This is especially puzzling given that I am using the same architecture as my predecessors, who had no overfitting problems like this, and I've added data to the training set, so I would've thought the model would be less prone to overfit.

It's possible there's a mistake in the code I'm missing. The model is built from several files, some of which I've omitted for space, but I wrote the following snippets to reflect the broad structure. Here's the Trainer object:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, models 
from modeling.backbone.resnet import ResNet101 
from RN101_newtop import RN101_newtop
from PerformanceMetrics import PerformanceMetrics 

#tensorboard implementation
from torch.utils.tensorboard import SummaryWriter

class Trainer(object): 
    def __init__(self, train_params, modelname="resnet"):
        self.batch_size = {} #establish dictionary of params for batch size, epoch num, learning rate
        self.epochs = {}
        self.lr = {}
        self.dataloaders = {'top': {}, 'all': {} } #need dataloaders for top & all training
        self.sig_thresholds = {}   #probability thresholds to assess model performance on
        self.model = None #placeholders, set in setup_model() and train()
        self.criterion = None
        self.optimizable_parameters = None

        self.batch_size['top'] = train_params['batch_size_top'] #define top vs all in batch_size dict
        self.batch_size['all'] = train_params['batch_size_all']
        self.epochs['top'] = train_params['epochs_top'] #30 number of epochs to train the top of the model
        self.epochs['all'] = train_params['epochs_all'] #20 number of epochs to train the entire model
        self.lr['top'] = 1e-4 #use a slightly faster learning rate for top
        self.lr['all'] = 1e-5 #previously 1e-5 %%

        self.modelname = modelname
        self.N_CLASSES = 7 #establish number of classes

        self.sig_thresholds['train'] = [0.5] #within threshold dictionary, diff vals for training/validation
        self.sig_thresholds['val'] = [0.3, 0.5, 0.7, 0.9] #???

        self.log_file = open("log_file.txt", "a") #opens the logfile to append # %%
        self.model_path_base = './modeling/saved_models/' #creates a path for this model # %%

        self.best_model = [] #space for storing best model weights, scores and thresholds
        self.best_score = 0.0
        self.best_sigthresh = 0.0

###
# DataLoader Setup
###
    def setup_dataloaders(self, dataset_dict, bShuffle, num_workers, samplers): #args provided in training.py
        for stage in ['top', 'all']: #loop over stages
            for phase in ['train', 'val']: #loop over phases
                self.dataloaders[stage][phase] = torch.utils.data.DataLoader(
                    dataset_dict[phase], batch_size=self.batch_size[stage],
                    shuffle=bShuffle, num_workers=num_workers)



###
# Model Setup
###

    def setup_model(self, distributed):
        if self.modelname == "resnet":
            pretrained_model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) #create pretrained backbone
            resnet_bottom = torch.nn.Sequential(*list(pretrained_model.children())[:-1]) #remove last (fc) layer
            model = RN101_newtop(base_model=resnet_bottom, num_classes=self.N_CLASSES) #attach new classification top
            self.model = model.cuda() #store the model on the GPU (inputs are moved to .cuda() in evaluate_epoch)

    def count_optimizable_parameters(self): #counts all the parameters for which grad is being tracked
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad) #numel returns length of input tensor

    def set_optimizable_parameters(self, stage): #function for setting parameter tracking states
        if stage == 'top': #when doing transfer learning:
            for param in self.model.parameters():
                param.requires_grad = False #turn off gradient tracking of all params (parameters will not be updated)
            if self.modelname in ('densenet', 'dpn', 'neat', 'neater'):
                params_to_optimize_in_top = list(self.model.classifier.parameters())
            else:
                params_to_optimize_in_top = list(self.model.fc.parameters()) #parameters of the new fully connected top

            for param in params_to_optimize_in_top:
                param.requires_grad = True #set those parameters to be updated
            #Goal is to keep the pretrained layers that detect basic features like edges,
            #while training the layer responsible for class predictions
            self.optimizable_parameters = params_to_optimize_in_top

        if stage == 'all': #during all stage:
            for param in self.model.parameters():
                param.requires_grad = True #all parameters are tracked and updated

            self.optimizable_parameters = self.model.parameters()

    def evaluate_batch(self, input, target, phase, sig_thresh):
        with torch.set_grad_enabled(phase == 'train'):  #track gradients only during the training phase
            sigfunc = nn.Sigmoid() #sigmoid converts logits to probabilities

            output = self.model(input)  #forward pass on the image batch
            loss = self.criterion(output, target)  #evaluate loss against the annotations

            #make predictions by thresholding the sigmoid outputs
            sig = sigfunc(output)
            sig = sig.to("cpu").detach().numpy()
            pred = sig > sig_thresh
            pred = pred.astype(int)

        return pred, loss

    def evaluate_epoch(self, phase, optimizer, dataloader, sig_thresh, epoch, ID, stage, model_path, bs=0):
        metrics = PerformanceMetrics() #metrics function
        best_score = bs
        

        for it, batch in enumerate(dataloader): #iterate over dataloader object
            input  = batch['X'].cuda()#to(device)
            target = batch['Y'].cuda()#to(device)

            pred, loss = self.evaluate_batch(input, target, phase, sig_thresh) #use prior function to evaluate batch


            if phase == 'train':
                loss.backward()  # update the gradients
                optimizer.step()  # update sgd optimizer lr
                optimizer.zero_grad()  # zero gradients

            # compare model output and annotations  - more easily done with numpy
            target = target.to("cpu").detach().numpy()  #take off gpu, detach from gradients
            target = target.astype(int)

            n_samples =  input.size(0)
            metrics.accumulate(loss.item(), n_samples, pred, target) 

        logdir = 'ALLruns/{}/{}_{}{}_sig_thresh_{:.2f}'.format(ID, ID, stage, phase, sig_thresh)
        writer = SummaryWriter(log_dir=logdir)
        writer.add_scalar("Loss_per/epoch", metrics.loss_per_sample, epoch) #same scalars for train and val; phase is encoded in logdir
        writer.add_scalar("F1/epoch", metrics.f1, epoch)
        writer.close() #flush events to disk


        return metrics

    #define function for training. Requires args for stage/criterion/optimizer
    def train(self, stage, criterion, optimizer, scheduler = None,  results = {'best_score': 0.0, 'best_model': []}, ID = None, model_path = None ):
        self.criterion = criterion 
        
        
        best_score = results['best_score'] #define best_score

        for epoch in range(self.epochs[stage]): #loop over num of epochs for a given stage
            for phase in ['train','val']: #loop over phase
                if phase == 'train':
                    self.model.train() #puts model in training mode
                else:
                    self.model.eval() #evaluation mode

                for sig_thresh in self.sig_thresholds[phase]: #loop over the thresholds for this phase
                    
                    #use evaluate_epoch
                    metrics = self.evaluate_epoch(phase, optimizer, self.dataloaders[stage][phase], sig_thresh,epoch=epoch, ID=ID,stage = stage,model_path =model_path,bs = self.best_score)




                    # save model if it is the best model on val set
                    if phase == 'val' and metrics.f1 > best_score:  # when a new high score is reached:
                        self.best_score = metrics.f1  # replace former best score
                        best_score = metrics.f1
                        self.best_sigthresh = sig_thresh  # record sigma threshold
                        full_model_path = '{}_{}_sig_{:.2f}.torch'.format(self.model_path_base, model_path,
                                                                          sig_thresh)  # record model path
                        self.best_model = full_model_path  # and store
                        torch.save(self.model, full_model_path)  # save the model
                        

                    print('{} Loss: {:.4f}, sig_thresh: {:.2f}, F1_score: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format(phase, metrics.loss_per_sample, sig_thresh, metrics.f1, metrics.precision, metrics.recall))

                    self.log_file.write('epoch\t{}\tphase\t{}\tsig_thresh:\t{:.2f}\tLoss\t{:.4f}\tF1\t{:.4f}\n'.format(epoch, phase, sig_thresh, metrics.loss_per_sample, metrics.f1))



                if scheduler: 
                    scheduler.step()
                    print("Learning rate :")
                    print(scheduler.get_last_lr())



        return {'best_score': self.best_score, 'best_model': self.best_model, 'sig_thresh': self.best_sigthresh}

Here's the metrics object:

import numpy as np

class PerformanceMetrics:  #this class is only appropriate for presence/absence right now
    def __init__(self):
        #raw data - accumulated as batches are processed
        self.loss = 0.0
        self.true_pos = 0
        self.true_neg = 0
        self.false_pos = 0
        self.false_neg = 0
        self.positives = 0
        self.negatives = 0
        self.samples = 0

        #metrics
        self.precision = 0.0
        self.recall = 0.0
        self.f1 = 0.0
        self.true_pos_rate = 0.0
        self.true_neg_rate = 0.0
        self.false_pos_rate = 0.0
        self.false_neg_rate = 0.0
        self.loss_per_sample = 0.0

    def accumulate(self, loss, batch_size, pred, target):
        corr = np.equal(target, pred)
        tp = np.where(target == 1, corr, False )  #true positives
        tn = np.where(target == 0, corr, False )  #true negatives
        fp = np.where(target == 0, np.logical_not(corr), False)  #false positives
        fn = np.where(target == 1, np.logical_not(corr), False)  # false negatives

        self.loss += loss * batch_size #loss arrives as a batch mean, so re-weight by batch size before summing
        self.true_pos  += np.sum(tp)
        self.true_neg  += np.sum(tn)
        self.false_pos += np.sum(fp)
        self.false_neg += np.sum(fn)
        self.positives  += np.sum(target)
        self.negatives += np.sum(np.logical_not(target))
        self.samples += batch_size

        self.precision = self.true_pos / (self.true_pos + self.false_pos)
        self.recall    = self.true_pos / (self.true_pos + self.false_neg)
        self.f1 = 2*(self.precision * self.recall)/(self.precision + self.recall)
        self.true_pos_rate = self.true_pos / self.positives
        self.true_neg_rate = self.true_neg / self.negatives
        self.false_pos_rate = self.false_pos /self.negatives
        self.false_neg_rate = self.false_neg /self.positives
        self.loss_per_sample = self.loss / self.samples
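
For reference, this is how I sanity-check the class in isolation (a quick illustrative example with made-up values, not part of the pipeline):

import numpy as np
m = PerformanceMetrics()
pred   = np.array([[1, 0, 1], [0, 0, 1]])  #thresholded model outputs for a batch of 2
target = np.array([[1, 0, 0], [0, 1, 1]])  #ground-truth presence/absence labels
m.accumulate(loss=0.42, batch_size=2, pred=pred, target=target)
print(m.precision, m.recall, m.f1, m.loss_per_sample)  #2/3, 2/3, 2/3, 0.42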

and finally, here is the script that actually runs the training:

import torch.nn as nn
import torch.utils.data.distributed
import torch.distributed as dist
from torchvision import transforms
import PIL
import yaml #library for dealing with .yaml files, which are the config file formats
import numpy as np
from Trainer import Trainer # %%
from marsh_plant_dataset import MarshPlant_Dataset
from Evaluator import Evaluator #Evaluator lives in one of the omitted files; module name illustrative



#establish a few parameters
image_dim = (512, 512)
crop_dim = (1000, 1000)


config = "config.yaml" #config path illustrative; the real one is supplied at runtime
ymlfile = open(config)
ymldata = yaml.load(ymlfile, Loader=yaml.FullLoader) #read in .yaml config file

# pull info from config file on model, data files, training params, etc
modelname = ymldata["model"]
datafiles = {'train' : ymldata["datafiles"]["train"], #['small_pa_sample.txt'],
                 'val'   : ymldata["datafiles"]["val"],
                 'test'  : ymldata["datafiles"]["test"]
             } 

train_params = { #dictionary of batch-size and epoch params for the top and all stages
        'batch_size_top': ymldata["batch_size"]["top"],
        'batch_size_all': ymldata["batch_size"]["all"],
        'epochs_top': ymldata["epochs"]["top"],
        'epochs_all': ymldata["epochs"]["all"]
} 
do_data_aug = ymldata["data_aug"] #flag: apply data augmentation if specified


#initialize the Trainer class - see Trainer script for further details
trainer = Trainer(train_params, modelname=modelname)


#setup data transforms
transforms_base = transforms.Compose([ #resize image chunks, transform to tensor, and normalize
        transforms.Resize(image_dim),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
#in the case of data augmentation
if do_data_aug:
    print("doing data aug")
    transform_train = transforms.Compose([
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(hue=.02, saturation=.02),
            transforms.RandomAffine(20, translate=None, scale = (0.8, 1.1), shear = 10, fill=0),
            transforms.CenterCrop(crop_dim),
            transforms_base
        ])
else: #if not, just use base transforms
    transform_train = transforms_base
#use the base transforms for the testing and validation sets
transform_test = transforms_base
transform_val = transforms_base

#load datasets and set up Datasets using the specified parts of the datafiles dict and transforms
train_data = MarshPlant_Dataset(datafiles['train'], train=True, transform=transform_train)
val_data   = MarshPlant_Dataset(datafiles['val'], train=True, transform=transform_val)
test_data  = MarshPlant_Dataset(datafiles['test'], train=True, transform=transform_test)
datasets = {'train': train_data, 'val': val_data}

#establish parameters for batch shuffling and number of dataloader workers
bShuffle = True
num_workers = 8
#universal ID
id = ymldata["id"]

samplers = None #placeholder; setup_dataloaders currently ignores this argument
distributed = False #placeholder; not used in the resnet branch of setup_model

with torch.cuda.device(0):
    #establish dataloaders using Trainer methods
    trainer.setup_dataloaders(datasets, bShuffle, num_workers, samplers=samplers)
    trainer.setup_model(distributed)  #setup model and training parameters

    #establish results dictionary and gamma dictionary
    results = {'best_score': 0.0, 'best_model': []}

    #set up gamma values (these are optimized values for Limonium performance!)
    gamma_top = 0.5
    gamma_all = 0.8
    gamma = {'top': gamma_top, 'all': gamma_all}  # all previously 0.8 %%

    modelpath = modelname + '_ID{}'.format(id)
    for stage in ['top', 'all']:  #loop through the top (transfer learning) and whole-model training stages
        trainer.set_optimizable_parameters(stage)  #set parameter tracking according to stage
        print("Training {}: {} parameters optimized".format(stage, trainer.count_optimizable_parameters()))

        optimizer = torch.optim.Adam(trainer.optimizable_parameters, lr=trainer.lr[stage])  #use the Adam optimizer

        criterion = nn.BCEWithLogitsLoss().cuda()  #Binary Cross Entropy with Logits loss
        #allows non-mutually-exclusive classification (multiple plants in an image)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=gamma[stage])  #learning rate scheduler
        results = trainer.train(stage, criterion, optimizer, scheduler=lr_scheduler, results=results,
                                ID=id, model_path=modelpath)
        print('Finished training {}, best score {:.4f}'.format(stage, results['best_score']))  #print current stage, best result so far

    #instantiate Evaluator on the results, passing sigma threshold, transforms, and model configuration
    evaluator = Evaluator(results['best_model'], sig_thresh=results['sig_thresh'], modelname=modelname,
                          transform=transform_test, config_file=config, ID=id)
    evaluator.setup_dataloader(test_data)  #use the setup_dataloader Evaluator method
    evaluator.run()  #this accomplishes much of Evaluator's tasks

I've tried a simpler architecture (going from resnext101 to resnet50) and data augmentation to see if either alleviated the problem, but the val loss still increases during the all stage. Should I reduce the number of trainable parameters further, e.g. by unfreezing only the top-most backbone block (sketched below)? Or refrain from unfreezing the bottom parameters at all?
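
For concreteness, this is the kind of partial unfreezing I mean, written against torchvision's stock resnet50 rather than our custom backbone (a minimal sketch; the block choice and learning rate are placeholders):

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 7)  #new 7-class multi-label top

for param in model.parameters():
    param.requires_grad = False  #freeze everything first
for param in model.layer4.parameters():
    param.requires_grad = True   #unfreeze only the last residual block...
for param in model.fc.parameters():
    param.requires_grad = True   #...plus the new top

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-5)  #only the unfrozen parameters are updated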

Thank you so much for reading all this, and for any help you can give!
