How to modify the code below to adapt BERT to multi-task learning

import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.distilbert.modeling_distilbert import DISTILBERT_INPUTS_DOCSTRING
from transformers.utils import add_start_docstrings_to_model_forward


class DistilBertForSpanCategorization(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)

        ### for NER ###
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        ### for argument classification ###
        self.dropout_args = nn.Dropout(config.dropout)
        self.classifier_args = nn.Linear(config.hidden_size, 3)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
    r"""
    labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
        1]``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.distilbert(
        input_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

        # for NER
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (batch_size, seq_len, num_labels)

        # for argument identification
        arg_sequence_output = outputs[0]
        arg_sequence_output = self.dropout_args(arg_sequence_output)
        logits_args = self.classifier_args(arg_sequence_output)  # (batch_size, seq_len, 3)


        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss_spans = loss_fct(logits, labels.float())

            # NOTE: this reuses the NER labels; in practice argument
            # classification would need its own label tensor whose shape
            # matches logits_args, i.e. (batch_size, seq_len, 3)
            loss_args = loss_fct(logits_args, labels.float())

            # combining the loss of NER and argument classification
            loss = loss_spans + loss_args

        if not return_dict:  ### how to modify this for multi-task?
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(  ### how to modify this for multi-task?
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

The code above adapts DistilBERT to do both NER and argument (span) classification. I don't know how to modify the return value of the forward function to accommodate both tasks: what should logits be for TokenClassifierOutput, given that two sets of logits are computed, one per task? The same question applies to the if not return_dict: branch.
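For reference, one common way to handle this (a minimal sketch, not an official transformers pattern) is to define a custom ModelOutput subclass with one logits field per task, and to put both logits into the tuple in the return_dict=False branch. The class name MultiTaskTokenClassifierOutput and the field logits_args below are illustrative names, not part of the library:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class MultiTaskTokenClassifierOutput(ModelOutput):
    # hypothetical output class: one set of logits per task
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None        # NER logits
    logits_args: torch.FloatTensor = None   # argument-classification logits
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# The two return branches in forward() could then become:
#
#     if not return_dict:
#         output = (logits, logits_args) + outputs[1:]
#         return ((loss,) + output) if loss is not None else output
#
#     return MultiTaskTokenClassifierOutput(
#         loss=loss,
#         logits=logits,
#         logits_args=logits_args,
#         hidden_states=outputs.hidden_states,
#         attentions=outputs.attentions,
#     )

Keeping loss as the first element of the tuple (and as a named field of the output) should keep this compatible with Trainer, which reads the loss from outputs[0] or outputs["loss"].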
