PyTorch Lightning on_save_checkpoint is not called

This is about my project, which is in the development phase (if you want to take a look or clone it, make sure you are on the "on_save_checkpointNotWorking" branch).

I have a class named BrazingTorch (in brazingTorchFolder/brazingTorch.py) that inherits from PyTorch Lightning's LightningModule, and I have defined the on_save_checkpoint method for this class (in brazingTorchFolder/brazingTorchParents/saveLoad.py), but it is not called by the .fit method (which, as you have probably guessed, runs the whole pipeline, including training_step etc.).

Note that I am using ModelCheckpoint in .fit; even though it actually saves the model, on_save_checkpoint is not called!

I have checked that on_save_checkpoint should be defined on BrazingTorch (the class which inherits from LightningModule), and it does exist on the model. I have also asked Gemini and ChatGPT about this problem; they suggested some things to check, which were trivial and already correctly implemented.

You may try the .fit method in tests\brazingTorchTests\fitTests.py. There I am calling the .fit method defined in brazingTorchFolder\brazingTorchParents\modelFitter.py (note that it is closely related to .baseFit in the same file).

Note that the loggerPath, and therefore where the checkpoint gets saved, is tests\brazingTorchTests\NNDummy1\arch1\mainRun_seed71.

  def fit(self, trainDataloader: DataLoader,
      valDataloader: Optional[DataLoader] = None,
      *, lossFuncs: List[nn.modules.loss._Loss],
      seed=None, resume=True, seedSensitive=False,
      addDefaultLogger=True, addDefault_gradientClipping=True,
      preRunTests_force=False, preRunTests_seedSensitive=False,
      preRunTests_lrsToFindBest=None,
      preRunTests_batchSizesToFindBest=None,
      preRunTests_fastDevRunKwargs=None, preRunTests_overfitBatchesKwargs=None,
      preRunTests_profilerKwargs=None, preRunTests_findBestLearningRateKwargs=None,
      preRunTests_findBestBatchSizesKwargs=None,
      **kwargs):

    if not seed:
      seed = self.seed

    self._setLossFuncs_ifNot(lossFuncs)

    architectureName, loggerPath, shouldRun_preRunTests = self._determineShouldRun_preRunTests(
      False, seedSensitive)


    loggerPath = loggerPath.replace('preRunTests', 'mainRun_seed71')

    checkpointCallback = ModelCheckpoint(
      monitor=f"{self._getLossName('val', self.lossFuncs[0])}",
      mode='min', # Save the model when the monitored quantity is minimized
      save_top_k=1, # Save the top model based on the monitored quantity
      every_n_epochs=1, # Checkpoint every 1 epoch
      dirpath=loggerPath, # Directory to save checkpoints
      filename='BrazingTorch',
    )
    callbacks_ = [checkpointCallback, StoreEpochData()]
    kwargsApplied = {
      'logger': pl.loggers.TensorBoardLogger(self.modelName, name=architectureName,
                          version='preRunTests'),
      'callbacks': callbacks_, }

    return self.baseFit(trainDataloader=trainDataloader, valDataloader=valDataloader,
              addDefaultLogger=addDefaultLogger,
              addDefault_gradientClipping=addDefault_gradientClipping,
              listOfKwargs=[kwargsApplied], **kwargs)

  @argValidator
  def baseFit(self, trainDataloader: DataLoader,
        valDataloader: Union[DataLoader, None] = None,
        addDefaultLogger=True, addDefault_gradientClipping=True,
        listOfKwargs: List[dict] = None,
        **kwargs):

    # cccUsage
    # - this method accepts kwargs related to trainer, trainer.fit, and self.log and
    #   passes them accordingly
    # - the order in listOfKwargs is important
    # - _logOptions phase-based values feature:
    #      - args related to self.log may be a dict with the keys 'train', 'val', 'test',
    #        'predict' or 'else'
    #      - this way you can specify which value each phase should use; phases without
    #        an explicit value fall back to the 'else' entry
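    #      - illustrative example (guessed usage, not taken from the project's docs):
    #        passing prog_bar={'train': True, 'else': False} would show the logged values
    #        in the progress bar only during the training phase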

    # put together all kwargs user wants to pass to trainer, trainer.fit, and self.log
    listOfKwargs = listOfKwargs or []
    listOfKwargs.append(kwargs)
    allUserKwargs = {}
    for kw in listOfKwargs:
      self._plKwargUpdater(allUserKwargs, kw)

    # add default logger if allowed and no logger is passed,
    # because by default we are logging some metrics
    if addDefaultLogger and 'logger' not in allUserKwargs:
      allUserKwargs['logger'] = pl.loggers.TensorBoardLogger(self.modelName)
      # bugPotentialCheck1
      # shouldn't this default logger have architectureName

    appliedKwargs = self._getArgsRelated_toEachMethodSeparately(allUserKwargs)

    notAllowedArgs = ['self', 'overfit_batches', 'name', 'value']
    self._removeNotAllowedArgs(allUserKwargs, appliedKwargs, notAllowedArgs)

    self._warnForNotUsedArgs(allUserKwargs, appliedKwargs)

    # add gradient clipping by default
    if not self.noAdditionalOptions and addDefault_gradientClipping \
        and 'gradient_clip_val' not in appliedKwargs['trainer']:
      appliedKwargs['trainer']['gradient_clip_val'] = 0.1
      Warn.info('gradient_clip_val is not provided to fit;' + \
           ' so it is set to the default value of 0.1.' + \
           '\nto cancel it, you may either pass noAdditionalOptions=True to the model or ' + \
           'pass addDefault_gradientClipping=False to the fit method,' + \
           '\nor set another value for "gradient_clip_val" in the kwargs passed to the fit method.')

    trainer = pl.Trainer(**appliedKwargs['trainer'])

    self._logOptions = appliedKwargs['log']

    if 'train_dataloaders' in appliedKwargs['trainerFit']:
      del appliedKwargs['trainerFit']['train_dataloaders']
    if 'val_dataloaders' in appliedKwargs['trainerFit']:
      del appliedKwargs['trainerFit']['val_dataloaders']
    trainer.fit(self, trainDataloader, valDataloader, **appliedKwargs['trainerFit'])

    self._logOptions = {}
    return trainer

  def on_save_checkpoint(self, checkpoint: dict):
    # reimplement this method to save additional information to the checkpoint

    # Add additional information to the checkpoint
    checkpoint['brazingTorch'] = {
      '_initArgs': self._initArgs,
      'allDefinitions': self.allDefinitions,
      'warnsFrom_getAllNeededDefinitions': self.warnsFrom_getAllNeededDefinitions,
    }
    return checkpoint
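
For comparison, here is a minimal self-contained sketch (the names in it are illustrative, not from my project) of the behavior I expect: with a plain LightningModule, on_save_checkpoint is called whenever ModelCheckpoint writes a checkpoint.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_save_checkpoint(self, checkpoint: dict):
        # expected to run every time a checkpoint file is written
        print('on_save_checkpoint called')
        checkpoint['extraInfo'] = {'note': 'added by the hook'}


dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
loader = DataLoader(dataset, batch_size=8)
trainer = pl.Trainer(max_epochs=1, logger=False,
                     callbacks=[ModelCheckpoint(save_top_k=1, every_n_epochs=1)])
trainer.fit(TinyModule(), loader)  # prints 'on_save_checkpoint called'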

There is 1 answer below

Farhang Amaji

on_save_checkpoint must be defined on the main class that inherits from PyTorch Lightning's LightningModule (here BrazingTorch), not on its parents.
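
A rough sketch of that situation and of the fix (class names below are illustrative, not the project's actual hierarchy): if the parent that defines the hook is listed after LightningModule in BrazingTorch's bases, LightningModule's own no-op on_save_checkpoint (inherited from its hook mixins) comes earlier in the method resolution order and shadows the parent's version; defining the hook directly on BrazingTorch makes it the first match in the MRO, so the trainer calls it when saving.

import pytorch_lightning as pl


class SaveLoadParent:
    # hook defined only on a parent/mixin class
    def on_save_checkpoint(self, checkpoint: dict):
        checkpoint['brazingTorch'] = {'note': 'extra info'}


class BrazingTorchBroken(pl.LightningModule, SaveLoadParent):
    # with this base order, LightningModule's no-op on_save_checkpoint precedes
    # SaveLoadParent in the MRO, so the parent's hook is shadowed and never runs
    pass


class BrazingTorchFixed(pl.LightningModule, SaveLoadParent):
    # defining the hook directly on the class that inherits LightningModule
    # makes it the first match in the MRO, so the trainer calls it when checkpointing
    def on_save_checkpoint(self, checkpoint: dict):
        checkpoint['brazingTorch'] = {'note': 'extra info'}


# LightningModule appears before SaveLoadParent here, which is why the parent's hook loses
print([cls.__name__ for cls in BrazingTorchBroken.__mro__])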