How to add a custom CRF head on top of BERT for token classification?

I want to write a custom forward loop in PyTorch and add a CRF head on top of a pre-trained BERT model for token classification (sequence tagging). I tried the two methods below, but both give me errors.

Method # 1:

    import torch
    from torch import nn
    from transformers import BertModel
    from torchcrf import CRF  # pytorch-crf package

    bert = BertModel.from_pretrained(MODEL, id2label=id2label, label2id=label2id)
    config = bert.config

    class BERT_CRF(nn.Module):
        def __init__(self):
            super(BERT_CRF, self).__init__()
            self.bert = bert
            self.dropout = nn.Dropout(0.1)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.crf = CRF(config.num_labels, batch_first=True)

        def forward(self, input_ids, token_type_ids, attention_mask, labels):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            sequence_output = outputs[0]  # last hidden state: (batch, seq_len, hidden_size)
            sequence_output = self.dropout(sequence_output)
            logits = self.classifier(sequence_output)  # emission scores: (batch, seq_len, num_labels)
            # the CRF returns a log-likelihood, so negate it to get a loss
            loss = -self.crf(emissions=logits, tags=labels, mask=attention_mask)
            return loss

    model = BERT_CRF().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    model.train()
    for batch in train_dataloader:
        model.zero_grad()
        loss = model(**batch.to(device))
        loss.backward()
        optimizer.step()
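
One thing I am not sure about in Method # 1 is the CRF call itself. If I read the pytorch-crf documentation correctly, forward returns the log-likelihood summed over the batch by default, and the mask is expected as a ByteTensor/BoolTensor, so the call probably should look more like this sketch:

    # Hedged variant of the CRF call in Method # 1: pytorch-crf expects a
    # ByteTensor/BoolTensor mask, and reduction='mean' averages the
    # log-likelihood over the batch instead of summing it (the default).
    loss = -self.crf(emissions=logits, tags=labels,
                     mask=attention_mask.byte(), reduction='mean')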

Method # 2:

    class BERT_CRF(nn.Module):
        def __init__(self):
            super(BERT_CRF, self).__init__()
            self.bert = bert
            self.dropout = nn.Dropout(0.1)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.crf = CRF(config.num_labels, batch_first=True)

        def tag_outputs(self, input_ids, token_type_ids=None, attention_mask=None):
            outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            sequence_output = outputs[0]
            sequence_output = self.dropout(sequence_output)
            emissions = self.classifier(sequence_output)
            return emissions

        def forward(self, input_ids, token_type_ids, attention_mask, labels):
            emissions = self.tag_outputs(input_ids, token_type_ids, attention_mask)
            loss = -1 * self.crf(emissions, labels, mask=attention_mask.byte())
            return loss
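
For inference I plan to reuse tag_outputs and let the CRF do the Viterbi decoding. This is only a sketch (eval_dataloader is my own and not shown here), but as far as I understand pytorch-crf, CRF.decode returns the best tag sequence per example as a list of ints:

    # Inference sketch (untested): CRF.decode runs Viterbi decoding and
    # returns one best tag list per sequence. eval_dataloader is assumed.
    model.eval()
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = batch.to(device)
            emissions = model.tag_outputs(batch["input_ids"],
                                          batch["token_type_ids"],
                                          batch["attention_mask"])
            best_tags = model.crf.decode(emissions, mask=batch["attention_mask"].byte())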

Here is the traceback I get:

    IndexError                                Traceback (most recent call last)
    Cell In[48], line 5
          3 for batch in train_dataloader:
          4     model.zero_grad()
    ----> 5     loss = model(**batch.to(device))
          6     print(loss)
          7     break

    File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    Cell In[47], line 15, in BERT_CRF.forward(self, input_ids, token_type_ids, attention_mask, labels)
         13 logits = self.classifier(sequence_output)
         14 outputs = (logits,)
    ---> 15 loss = self.crf(emissions = logits, tags=labels, mask=attention_mask)
         16 outputs =(-1*loss,)+outputs
         17 return outputs

    File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
       1496 # If we don't have any hooks, we want to skip the rest of the logic in
       1497 # this function, and just call forward.
       1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1499         or _global_backward_pre_hooks or _global_backward_hooks
       1500         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1501     return forward_call(*args, **kwargs)
       1502 # Do not call functions when jit is used
       1503 full_backward_hooks, non_full_backward_hooks = [], []

    File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torchcrf\__init__.py:102, in CRF.forward(self, emissions, tags, mask, reduction)
         99     mask = mask.transpose(0, 1)
        101 # shape: (batch_size,)
    --> 102 numerator = self._compute_score(emissions, tags, mask)
        103 # shape: (batch_size,)
        104 denominator = self._compute_normalizer(emissions, mask)

    File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torchcrf\__init__.py:186, in CRF._compute_score(self, emissions, tags, mask)
        182 mask = mask.float()
        184 # Start transition score and first emission
        185 # shape: (batch_size,)
    --> 186 score = self.start_transitions[tags[0]]
        187 score += emissions[0, torch.arange(batch_size), tags[0]]
        189 for i in range(1, seq_length):
        190     # Transition score to next tag, only added if next timestep is valid (mask == 1)
        191     # shape: (batch_size,)

    IndexError: index -100 is out of bounds for dimension 0 with size 2
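
My suspicion is that the -100 labels are the problem: the HuggingFace token-classification preprocessing labels special tokens and subword/padding positions with -100 because nn.CrossEntropyLoss ignores that index, but pytorch-crf indexes its transition tables directly with the tag values, so every tag has to be a valid index in [0, num_labels). A minimal sketch of the workaround I am considering, assuming -100 is only used for ignored positions (it keeps attention_mask as the CRF mask, since pytorch-crf requires the first timestep of every sequence to be unmasked):

    # Sketch of a possible fix for the IndexError: remap the -100 "ignore"
    # labels to a real tag index before calling the CRF. Note this treats
    # special tokens and subword pieces as tag 0 rather than excluding
    # them, which is a rough but common workaround.
    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        emissions = self.tag_outputs(input_ids, token_type_ids, attention_mask)
        crf_labels = labels.clamp(min=0)  # -100 -> 0; any index in [0, num_labels) works
        loss = -self.crf(emissions, tags=crf_labels, mask=attention_mask.bool())
        return loss

Is this the right approach, or is there a cleaner way to make the CRF skip the -100 positions entirely?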