In my project (which is in the development phase; if you want to take a look or clone it, make sure to use the "on_save_checkpointNotWorking" branch), I have a class named BrazingTorch (in brazingTorchFolder/brazingTorch.py) that inherits from PyTorch Lightning's LightningModule. I have defined the on_save_checkpoint method for this class (in brazingTorchFolder/brazingTorchParents/saveLoad.py), but it is never called by the .fit method (which, as you have probably guessed, runs the whole pipeline including training_step, etc.).
Note that I am using ModelCheckpoint in .fit; it actually saves the model, yet on_save_checkpoint is still not called!
I have checked that on_save_checkpoint should live on BrazingTorch (the class which inherits from LightningModule); that is the case, and the method does exist on the model. I have also asked Gemini and ChatGPT about this problem; they suggested some trivial things to check, which again were already implemented correctly.
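As a sanity check, a snippet like the one below (purely illustrative, not part of the project; model stands for any constructed instance of the BrazingTorch subclass) shows which class in the MRO actually provides the hook that Lightning would call:

# illustrative check: find the class that actually provides on_save_checkpoint
hookOwner = next(cls for cls in type(model).__mro__
                 if 'on_save_checkpoint' in cls.__dict__)
print(hookOwner)                       # here: the parent class from saveLoad.py
print(type(model).on_save_checkpoint)  # the function Lightning is supposed to call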
You may try the .fit method in tests\brazingTorchTests\fitTests.py; there I am calling the .fit method that lives in brazingTorchFolder\brazingTorchParents\modelFitter.py (note it is closely related to .baseFit in the same file).
Also note that the loggerPath, and therefore the directory the checkpoint is saved to, is tests\brazingTorchTests\NNDummy1\arch1\mainRun_seed71.
Here is fit:
def fit(self, trainDataloader: DataLoader,
        valDataloader: Optional[DataLoader] = None,
        *, lossFuncs: List[nn.modules.loss._Loss],
        seed=None, resume=True, seedSensitive=False,
        addDefaultLogger=True, addDefault_gradientClipping=True,
        preRunTests_force=False, preRunTests_seedSensitive=False,
        preRunTests_lrsToFindBest=None,
        preRunTests_batchSizesToFindBest=None,
        preRunTests_fastDevRunKwargs=None, preRunTests_overfitBatchesKwargs=None,
        preRunTests_profilerKwargs=None, preRunTests_findBestLearningRateKwargs=None,
        preRunTests_findBestBatchSizesKwargs=None,
        **kwargs):
    if not seed:
        seed = self.seed
    self._setLossFuncs_ifNot(lossFuncs)

    architectureName, loggerPath, shouldRun_preRunTests = self._determineShouldRun_preRunTests(
        False, seedSensitive)
    loggerPath = loggerPath.replace('preRunTests', 'mainRun_seed71')

    checkpointCallback = ModelCheckpoint(
        monitor=f"{self._getLossName('val', self.lossFuncs[0])}",
        mode='min',  # save the model when the monitored quantity is minimized
        save_top_k=1,  # keep only the best model based on the monitored quantity
        every_n_epochs=1,  # checkpoint every epoch
        dirpath=loggerPath,  # directory to save checkpoints to
        filename='BrazingTorch',
    )
    callbacks_ = [checkpointCallback, StoreEpochData()]

    kwargsApplied = {
        'logger': pl.loggers.TensorBoardLogger(self.modelName, name=architectureName,
                                               version='preRunTests'),
        'callbacks': callbacks_, }

    return self.baseFit(trainDataloader=trainDataloader, valDataloader=valDataloader,
                        addDefaultLogger=addDefaultLogger,
                        addDefault_gradientClipping=addDefault_gradientClipping,
                        listOfKwargs=[kwargsApplied], **kwargs)
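For reference, a call to fit looks roughly like this (a hypothetical sketch; the dataloaders and loss are placeholders, and extra kwargs such as max_epochs get routed to the pl.Trainer inside baseFit):

model.fit(trainDataloader, valDataloader,
          lossFuncs=[nn.MSELoss()],  # required keyword-only argument
          max_epochs=3)              # forwarded to pl.Trainer by baseFit

And here is baseFit from the same file: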
@argValidator
def baseFit(self, trainDataloader: DataLoader,
            valDataloader: Union[DataLoader, None] = None,
            addDefaultLogger=True, addDefault_gradientClipping=True,
            listOfKwargs: List[dict] = None,
            **kwargs):
    # cccUsage
    # - this method accepts kwargs related to the trainer, trainer.fit and self.log and
    #   passes them on accordingly
    # - the order in listOfKwargs is important
    # - _logOptions phase-based values feature:
    #   - args related to self.log may be a dict with the keys 'train', 'val', 'test',
    #     'predict' or 'else'
    #   - this way you can specify which value each phase uses; phases not listed
    #     explicitly fall back to the 'else' value
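    #   - e.g. (hypothetical example, not from the project) passing
    #     prog_bar={'train': True, 'val': True, 'else': False} in kwargs would make
    #     self.log use prog_bar=True only during the train and val phases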
    # put together all kwargs the user wants to pass to the trainer, trainer.fit and self.log
    listOfKwargs = listOfKwargs or []
    listOfKwargs.append(kwargs)
    allUserKwargs = {}
    for kw in listOfKwargs:
        self._plKwargUpdater(allUserKwargs, kw)

    # add the default logger if allowed and no logger is passed,
    # because by default we are logging some metrics
    if addDefaultLogger and 'logger' not in allUserKwargs:
        allUserKwargs['logger'] = pl.loggers.TensorBoardLogger(self.modelName)
        # bugPotentialCheck1
        #  shouldn't this default logger have architectureName?

    appliedKwargs = self._getArgsRelated_toEachMethodSeparately(allUserKwargs)

    notAllowedArgs = ['self', 'overfit_batches', 'name', 'value']
    self._removeNotAllowedArgs(allUserKwargs, appliedKwargs, notAllowedArgs)
    self._warnForNotUsedArgs(allUserKwargs, appliedKwargs)

    # add gradient clipping by default
    if not self.noAdditionalOptions and addDefault_gradientClipping \
            and 'gradient_clip_val' not in appliedKwargs['trainer']:
        appliedKwargs['trainer']['gradient_clip_val'] = 0.1
        Warn.info('gradient_clip_val was not provided to fit, so it is set to 0.1 by default.'
                  '\nto cancel this, you may either pass noAdditionalOptions=True to the model, '
                  'pass addDefault_gradientClipping=False to the fit method, '
                  '\nor set another value for "gradient_clip_val" in the kwargs passed to fit.')

    trainer = pl.Trainer(**appliedKwargs['trainer'])
    self._logOptions = appliedKwargs['log']

    if 'train_dataloaders' in appliedKwargs['trainerFit']:
        del appliedKwargs['trainerFit']['train_dataloaders']
    if 'val_dataloaders' in appliedKwargs['trainerFit']:
        del appliedKwargs['trainerFit']['val_dataloaders']

    trainer.fit(self, trainDataloader, valDataloader, **appliedKwargs['trainerFit'])
    self._logOptions = {}
    return trainer
And this is the on_save_checkpoint that is never called (defined in brazingTorchFolder/brazingTorchParents/saveLoad.py):

def on_save_checkpoint(self, checkpoint: dict):
    # reimplemented to add additional information to the checkpoint
    checkpoint['brazingTorch'] = {
        '_initArgs': self._initArgs,
        'allDefinitions': self.allDefinitions,
        'warnsFrom_getAllNeededDefinitions': self.warnsFrom_getAllNeededDefinitions,
    }
    return checkpoint
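A quick way to see whether the hook ever ran is to load the saved checkpoint directly and look for the custom key (the path below is just the directory mentioned above plus the ModelCheckpoint filename; Lightning checkpoints are plain torch-loadable dicts):

import torch

ckpt = torch.load(r'tests\brazingTorchTests\NNDummy1\arch1\mainRun_seed71\BrazingTorch.ckpt',
                  map_location='cpu')
print('brazingTorch' in ckpt)  # False means on_save_checkpoint was never called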
It turned out that on_save_checkpoint must be defined on the main class inheriting from PyTorch Lightning's LightningModule (here BrazingTorch), and not on its parent classes.
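For anyone hitting the same thing, here is a minimal, self-contained sketch of the pattern that works: the hook is defined directly on the class that gets passed to trainer.fit (the model, data and paths below are dummies, not the project code):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModule(pl.LightningModule):  # the hook lives directly on the class given to .fit
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_save_checkpoint(self, checkpoint):
        # called by the Trainer whenever ModelCheckpoint writes a .ckpt
        checkpoint['extraInfo'] = {'note': 'hook was called'}


ds = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
dl = DataLoader(ds, batch_size=8)

checkpointCallback = ModelCheckpoint(dirpath='tmpCkpts', filename='tiny', every_n_epochs=1)
trainer = pl.Trainer(max_epochs=1, callbacks=[checkpointCallback], logger=False)
trainer.fit(TinyModule(), dl)

loaded = torch.load(checkpointCallback.best_model_path, map_location='cpu')
print('extraInfo' in loaded)  # True -> on_save_checkpoint ran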