import torch
from torch import nn
from transformers import DistilBertModel, DistilBertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
# NOTE: these two import paths can differ across transformers versions
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.models.distilbert.modeling_distilbert import DISTILBERT_INPUTS_DOCSTRING


class DistilBertForSpanCategorization(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        ### head for NER / span categorization ###
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        ### head for argument classification (3 classes) ###
        self.dropout_args = nn.Dropout(config.dropout)
        self.classifier_args = nn.Linear(config.hidden_size, 3)
        self.init_weights()
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # NER / span head
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (batch_size, seq_len, num_labels)

        # argument-classification head
        arg_sequence_output = outputs[0]
        arg_sequence_output = self.dropout_args(arg_sequence_output)
        logits_args = self.classifier_args(arg_sequence_output)  # (batch_size, seq_len, 3)
        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss_spans = loss_fct(logits, labels.float())
            # NB: BCEWithLogitsLoss needs targets shaped like its logits, so reusing
            # `labels` for the 3-way argument head only works if num_labels == 3
            loss_args = loss_fct(logits_args, labels.float())
            loss = loss_spans + loss_args  # combine NER and argument-classification losses
        if not return_dict:  ### how to modify this for multi-task?
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(  ### how to modify this for multi-task?
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Above is code adapting DistilBERT to do NER and argument (span) classification jointly. I don't know how to modify the return value of the forward function to accommodate both tasks: what should `logits` be for `TokenClassifierOutput`, given that two sets of logits are computed (one per task's loss)? The `if not return_dict:` branch has the same problem.
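One idea I've been sketching (the class name `MultiTaskTokenClassifierOutput` and the `logits_args` field are my own invention, and I'm not sure this is the idiomatic transformers way) is to subclass `ModelOutput` with one logits field per task:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.utils import ModelOutput  # older versions: transformers.file_utils


@dataclass
class MultiTaskTokenClassifierOutput(ModelOutput):
    # hypothetical output class carrying one logits tensor per task
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None        # NER / span logits
    logits_args: torch.FloatTensor = None   # argument-classification logits
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

The tail of forward would then carry both logits in both the tuple path and the dict path:

        if not return_dict:
            # plain-tuple output: both logits go in front of hidden states / attentions
            output = (logits, logits_args) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultiTaskTokenClassifierOutput(
            loss=loss,
            logits=logits,
            logits_args=logits_args,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

A caller could then read both heads with out = model(**batch) followed by out.logits and out.logits_args. Is that a reasonable approach, or is there a standard multi-task output class I'm missing?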