Accessing the N values of PyTorch's Cross-Entropy loss function


I am trying to access specific values of PyTorch's cross-entropy loss function (torch.nn.functional.cross_entropy) that I believe are being calculated when the input is a batch of N items. I would like access to the vector of N individual loss values, not the mean or sum that the function returns by default.

From looking at the documentation, I tried setting reduction=None; however, the function still returns a scalar value. (The function's default setting is to return the mean.)
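
For reference, here is a minimal sketch of the behavior I expected. The documentation lists the string values 'none', 'mean', and 'sum' for the reduction argument, so I am passing the string 'none' rather than the Python value None; the variable names are just illustrative:

import torch
import torch.nn.functional as F

n_classes = 5
N = 4
logits = torch.randn(N, n_classes)        # batched input: N items, n_classes each
targets = torch.randint(n_classes, (N,))  # one target class index per item

# reduction='none' (the string) should keep the per-sample losses
per_sample = F.cross_entropy(logits, targets, reduction='none')
print(per_sample.shape) # prints torch.Size([4]): one loss value per item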

Here is the error message:

Exception has occurred: IndexError

IndexError: slice() cannot be applied to a 0-dim tensor.

Here is the code snippet leading up to the error:

tMinusOne_loss = burnIn_model.loss(combined_tMinusOne_X, combined_tMinusOne_Y)

print("tMinusOne_loss:", tMinusOne_loss)

tMinusOne_first_loss = tMinusOne_loss[:len(combined_tMinusOne_X_first)]

Here is the output of the print statement just before the error occurs:

tMinusOne_loss: tensor(0.3171, grad_fn=<NllLossBackward0>)

Thank you!

1 Answer (by Karl):

If I'm understanding the problem correctly, you are computing the cross-entropy loss of a single vector of size (N) or (1, N) (i.e. a single item, not a batch). In that case there is only one loss value to expect: CrossEntropyLoss produces one value per item in the batch.

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss(reduction='none')

n_classes = 5
batch_size = 3
input = torch.randn(batch_size, n_classes)
target = torch.empty(batch_size, dtype=torch.long).random_(n_classes)
output = loss(input, target)
output.shape
> torch.Size([3]) # output is one value for each item in the batch

batch_size = 1
input = torch.randn(batch_size, n_classes)
target = torch.empty(batch_size, dtype=torch.long).random_(n_classes)
output = loss(input, target)
output.shape
> torch.Size([1]) # output is one value for each item in the batch
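
Note also that recent PyTorch versions accept unbatched inputs, and in that case the result is a 0-dim tensor even with reduction='none', which would explain the slice() error in the question. A minimal sketch, assuming such a version:

unbatched_input = torch.randn(n_classes) # shape (C,): no batch dimension
unbatched_target = torch.tensor(2)       # scalar target class
output = loss(unbatched_input, unbatched_target)
output.shape
> torch.Size([]) # 0-dim tensor; slicing it raises IndexError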

Cross entropy is computed as the negative log probability of the target class. Since there is only one target class per item, there is only one loss value per item.

loss = nn.CrossEntropyLoss(reduction='none')

n_classes = 5
batch_size = 1
input = torch.randn(batch_size, n_classes)
target = torch.empty(batch_size, dtype=torch.long).random_(n_classes)

loss_value = loss(input, target)

log_probs = nn.functional.log_softmax(input, dim=-1)

# this indexing is valid here because batch_size == 1; for a general batch,
# use log_probs[torch.arange(batch_size), target]
torch.allclose(-log_probs[:, target], loss_value) # CE loss equals the negative log prob of the target class
> True
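
Applied back to the question: once the loss is computed over a real batch with reduction='none', the result is a 1-D tensor of length N and the slicing works as intended. A sketch with illustrative stand-ins (logits and targets are assumptions, not the question's actual variables):

logits = torch.randn(8, n_classes)        # a batch of N = 8 items
targets = torch.randint(n_classes, (8,))
per_sample = nn.functional.cross_entropy(logits, targets, reduction='none')
per_sample[:3].shape # slicing a 1-D tensor of length 8 now works
> torch.Size([3])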