How do I make my code save and resume checkpoints correctly based on the best validation loss?


I want a single checkpoint file that is constantly overwritten by the next checkpoint. It must save the epoch, training loss, and validation loss from the epoch with the best validation loss, so that the next training session resumes correctly from that point.

My code for this:

    # Save the checkpoint only if the validation loss improves
    if val_loss < best_loss:
        best_loss = val_loss
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_loss,
            'log': log
        }, checkpoint_path)

Problem: When training resumes, it continues after the last epoch of the previous run rather than after the last epoch with the best validation loss (for example: the last epoch was epoch 90 and the last epoch with the best validation loss was epoch 80, so the next training session should continue with epoch 81).

I ran a test training until an epoch no longer improved the best validation loss, which happened at epoch 5; epoch 4 was the last epoch with the best validation loss.

I exited the training while epoch 6 was running and then resumed training.

It continued training at epoch 6, but it should have continued at epoch 5, because epoch 4 was the last epoch of the previous run with the best validation loss.
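
For reference, this is roughly the resume behavior I am trying to get (a minimal sketch rather than my actual resume code; `checkpoint_dir`, `model`, `optimizer`, and `num_epochs` are assumed to come from the rest of my script):

    import os
    import torch

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')

    if os.path.exists(checkpoint_path):
        # Restore the best-validation-loss state saved by the code above
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_loss = checkpoint['best_loss']
        log = checkpoint['log']
        # Continue at the epoch after the saved (best) one,
        # e.g. start at epoch 5 if epoch 4 was the best epoch saved
        start_epoch = checkpoint['epoch'] + 1
    else:
        best_loss = float('inf')
        start_epoch = 0

    for epoch in range(start_epoch, num_epochs):
        # ... training and validation loop as before ...
        pass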
