PyTorch Lightning: loading a checkpoint


This is my code. My task is multi-label classification; I trained this model for 3 epochs and now I want to load my checkpoint and use it.

import torch
import pytorch_lightning as pl
from transformers import AutoModel, get_linear_schedule_with_warmup


class QTagClassifier(pl.LightningModule):
    # Set up the classifier
    def __init__(self, n_classes=10, steps_per_epoch=None, n_epochs=3, lr=2e-5 ):
        super().__init__()

        self.bert = AutoModel.from_pretrained(
            "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12",
            from_flax=True,
            problem_type="multi_label_classification",
            num_labels=10,
        )  # works
        

        # freeze the BERT encoder so only the classifier head is trained
        for param in self.bert.parameters():
            param.requires_grad = False

        self.classifier = torch.nn.Sequential(
                          torch.nn.Flatten(),
                          torch.nn.Linear(768, 64),
                          torch.nn.BatchNorm1d(64),
                          torch.nn.ReLU(),
                          torch.nn.Linear(64, 32),
                          torch.nn.BatchNorm1d(32),
                          torch.nn.ReLU(),
                          torch.nn.Linear(32, 10)
                        )
        #self.classifier = torch.nn.Linear(self.bert.config.hidden_size,n_classes) # outputs = number of labels
        self.steps_per_epoch = steps_per_epoch
        self.n_epochs = n_epochs
        self.lr = lr
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def forward(self,input_ids, attn_mask):
        output = self.bert(input_ids = input_ids ,attention_mask = attn_mask)
        output = self.classifier(output.pooler_output)

        return output


    def training_step(self,batch,batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        outputs = self(input_ids,attention_mask)
        loss = self.criterion(outputs,labels)
        self.log('train_loss',loss , prog_bar=True,logger=True)

        return {"loss" :loss, "predictions":outputs, "labels": labels }


    def validation_step(self,batch,batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        outputs = self(input_ids,attention_mask)
        loss = self.criterion(outputs,labels)
        self.log('val_loss',loss , prog_bar=True,logger=True)

        return loss

    def test_step(self,batch,batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        outputs = self(input_ids,attention_mask)
        loss = self.criterion(outputs,labels)
        self.log('test_loss',loss , prog_bar=True,logger=True)

        return loss


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters() , lr=self.lr)
        warmup_steps = self.steps_per_epoch//3
        total_steps = self.steps_per_epoch * self.n_epochs - warmup_steps

        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,total_steps)

        return [optimizer], [scheduler]

This is how I instantiate the model:

steps_per_epoch = len(x_tr)//BATCH_SIZE
model = QTagClassifier(n_classes=10, steps_per_epoch=steps_per_epoch, n_epochs=N_EPOCHS, lr=LR)
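
For completeness, this is roughly how I train and save it (a sketch; QTdata_module is my LightningDataModule, and any Trainer arguments beyond max_epochs are whatever defaults I was using):

trainer = pl.Trainer(max_epochs=N_EPOCHS)
trainer.fit(model, QTdata_module)
# save the checkpoint I later try to reload
trainer.save_checkpoint("example.ckpt")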

I read the PyTorch Lightning docs and they say to save checkpoints like this: trainer.save_checkpoint("example.ckpt"). This is how I wanted to load my checkpoint: new_model = model.load_from_checkpoint(checkpoint_path="example.ckpt")
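
Spelled out, the load I am attempting looks like this (a sketch; load_from_checkpoint is a classmethod, and passing the constructor arguments again is my own addition, since the module never calls self.save_hyperparameters()):

new_model = QTagClassifier.load_from_checkpoint(
    "example.ckpt",
    n_classes=10,
    steps_per_epoch=steps_per_epoch,
    n_epochs=N_EPOCHS,
    lr=LR,
)
new_model.eval()  # for inference on the multi-label task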

I did that, but when I try to load my checkpoint I get this error:

Error(s) in loading state_dict for QTagClassifier:
    Unexpected key(s) in state_dict: "bert.embeddings.position_ids".
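
In case it is relevant, this is how the raw checkpoint can be inspected to see which keys it actually contains (a sketch, assuming the usual Lightning checkpoint layout with a "state_dict" entry):

import torch

ckpt = torch.load("example.ckpt", map_location="cpu")
# Lightning keeps the model weights under the "state_dict" key
state_dict = ckpt["state_dict"]
print([k for k in state_dict if "position_ids" in k])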

Any help would be appreciated, thank you.

I also tried to load it through the Trainer with trainer.fit(model, QTdata_module, ckpt_path="example.ckpt"), and the same thing happens.

1 Answer

Answer from SehCode:

Did you use a different computer/Python environment, or upgrade your packages, between training the model and loading it? I had the same problem and that was my issue. As far as I can tell, newer transformers releases no longer keep bert.embeddings.position_ids in the model's state_dict, so a checkpoint saved with an older version contains a key the newer model does not expect. I fixed it by downgrading the transformers version to match my other computer. If you upgraded your package and are unsure what you were using before, try

transformers==4.26.1 

and work your way down through the versions.
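
For example, you can compare the transformers version in the environment that loads the checkpoint against the one that trained it (a sketch; 4.26.1 is just a starting point):

import transformers

# both environments should report the same version; if they differ,
# pin the older release, e.g.  pip install "transformers==4.26.1"
print(transformers.__version__)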