I suspect part of my PyTorch model is not saved. How can I verify this?


I am testing a BiLSTM + CRF model for Named Entity Recognition.

When I save the model and reload it, it performs as poorly as a newly initialized model, but it then learns much faster than a genuinely new model does.

The following is its per-tag F1 on the test set. The break point is where I stop and reload. I calculate these values with PyCM.

[Image: per-tag F1 on the test set]

I suspect that part of my model is not really saved, for example the pytorch-crf layer, because it comes from a library outside PyTorch itself. But I don't know why that would happen or how to verify it.

This is how my model is defined (simplified):

import torch
import torchcrf

class NerModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.num_classes = 7
        self.embedding_dim = 300
        # The bidirectional LSTM outputs 2 * hidden_size features per token,
        # so hidden_size must equal embedding_dim to match the linear layer below.
        self._bilstm = torch.nn.LSTM(
            input_size = self.embedding_dim,
            hidden_size = self.embedding_dim,
            batch_first = True,
            bidirectional = True
        )
        self._linear = torch.nn.Linear(self.embedding_dim * 2, self.num_classes)
        self._crf = torchcrf.CRF(self.num_classes, batch_first = True)

    def forward(self, embeddings:torch.Tensor, mask:torch.Tensor) -> torch.Tensor:
        out, _ = self._bilstm(embeddings)
        out = self._linear(out)
        # decode() returns one list of tag indices per sentence (unpadded).
        out = self._crf.decode(emissions = out, mask = mask)
        # Pad every sentence to the full sequence length; index 7 marks padding.
        padded_out = torch.full(mask.size(), self.num_classes, dtype=torch.int64)
        for ib, sentence_label in enumerate(out):
            for ic, token_label in enumerate(sentence_label):
                padded_out[ib, ic] = token_label
        return padded_out

    def calculate_neg_log_likelihood_loss(self, embeddings:torch.Tensor, mask:torch.Tensor, gold_labels):
        out, _ = self._bilstm(embeddings)
        out = self._linear(out)
        # The CRF returns the log-likelihood, so negate it to get a loss.
        loss = - self._crf(emissions = out, tags = gold_labels, mask = (mask == 1))
        return loss
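One way to check whether a submodule such as `_crf` contributes anything to the saved state is to list the keys of `state_dict()` that start with its attribute name. The sketch below avoids a dependency on torchcrf by using a hypothetical stand-in class (`CrfStandIn`) that owns an `nn.Parameter`; the helper name `keys_with_prefix` is likewise made up for illustration. As far as I know, pytorch-crf registers its transition scores as `nn.Parameter` objects, so the same check against the real model should show keys beginning with `_crf.`, but that is worth confirming rather than assuming.

```python
from typing import List

import torch


class CrfStandIn(torch.nn.Module):
    """Hypothetical stand-in for torchcrf.CRF: a module that owns a parameter."""

    def __init__(self, num_tags: int) -> None:
        super().__init__()
        self.transitions = torch.nn.Parameter(torch.randn(num_tags, num_tags))


class TinyModel(torch.nn.Module):
    """Hypothetical reduced model with the same attribute naming as NerModel."""

    def __init__(self) -> None:
        super().__init__()
        self._linear = torch.nn.Linear(4, 3)
        # Assigning a Module as an attribute registers it as a submodule,
        # so its parameters become part of state_dict().
        self._crf = CrfStandIn(3)


def keys_with_prefix(model: torch.nn.Module, prefix: str) -> List[str]:
    """Return the state_dict keys that start with the given prefix, sorted."""
    return sorted(k for k in model.state_dict().keys() if k.startswith(prefix))


model = TinyModel()
crf_keys = keys_with_prefix(model, "_crf.")
print(crf_keys)
```

If the list comes back empty for the real model, the CRF's tensors are not registered and will not be saved. If it is non-empty, the save side is probably fine and the loading code is the next place to look.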

This is how I save the model:

    def _save_state(self, i_epoch:int, batch_count:int) -> None:
        torch.save(
            {
                'epoch' : i_epoch,
                'batch_count' : batch_count,
                'model_state' : self.model.state_dict(),
                'optimizer_state' : self.optimizer.state_dict()
            },
            './saves/model.pt'
        )
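The loading code is not shown, so the following is only a sketch of how a round trip could be verified, assuming the same dictionary keys as in `_save_state`. It saves to an in-memory buffer instead of `./saves/model.pt`, uses a hypothetical small `build_model()` in place of `NerModel()`, and assumes an Adam optimizer. The two things it checks are the `missing_keys` / `unexpected_keys` reported by `load_state_dict`, and whether every tensor is identical after reloading.

```python
import io

import torch


def build_model() -> torch.nn.Module:
    # Hypothetical small model; substitute NerModel() for the real check.
    return torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3)
    )


trained = build_model()
optimizer = torch.optim.Adam(trained.parameters(), lr=1e-2)

# Take one optimizer step so weights and optimizer state differ from initialization.
loss = trained(torch.randn(5, 4)).sum()
loss.backward()
optimizer.step()

# Save with the same dictionary layout as _save_state, but into memory.
buffer = io.BytesIO()
torch.save(
    {"model_state": trained.state_dict(), "optimizer_state": optimizer.state_dict()},
    buffer,
)
buffer.seek(0)

# Reload into a freshly constructed model and optimizer.
checkpoint = torch.load(buffer)
restored = build_model()
restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-2)
result = restored.load_state_dict(checkpoint["model_state"], strict=True)
restored_optimizer.load_state_dict(checkpoint["optimizer_state"])

# Collect the names of any tensors that differ after the round trip.
mismatched = [
    name
    for name, tensor in trained.state_dict().items()
    if not torch.equal(tensor, restored.state_dict()[name])
]
print(result.missing_keys, result.unexpected_keys, mismatched)
```

All three lists should be empty. If they are empty for the real model as well, the checkpoint itself is complete, and a reloaded model that behaves like a fresh one more likely points to the reload path, for example `load_state_dict` never being called on the model that is actually evaluated.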