PyTorch: apply an operation to some rows but not all


I'm trying to apply a softmax function to some rows of a tensor, but some of my rows contain only -inf values. Softmax on these rows outputs NaN, which causes problems later in the model.

I therefore want a function that applies softmax to a row unless the row is entirely -inf, in which case it outputs a zero vector. Is there an easy way to do this?
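A quick way to see where the NaN comes from (a minimal sketch added for illustration): each softmax output is exp(x_i) divided by the sum of exp(x_j) over the row. When every entry is -inf this is 0/0 (and the usual subtract-the-max trick gives -inf - (-inf)), and both are NaN in floating point. A row with at least one finite entry is unaffected; its -inf positions simply get probability 0.

```python
import torch
import torch.nn.functional as F

# A row that is entirely -inf: every output entry becomes NaN.
all_inf = torch.full((4,), float("-inf"))
print(F.softmax(all_inf, dim=0))  # tensor([nan, nan, nan, nan])

# A row with at least one finite entry: the -inf positions get probability 0,
# and the finite positions still sum to 1.
partial = torch.tensor([1.0, float("-inf"), 2.0, float("-inf")])
print(F.softmax(partial, dim=0))
```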

Would it work for you to set to 0 every row that is entirely NaN after the softmax? Checking for rows that are all NaN (rather than rows containing any NaN) ensures you don't silently overwrite NaNs that show up unexpectedly for some other reason.

import numpy as np
import torch
import torch.nn.functional as F

array = np.arange(25, dtype=np.float32).reshape((5, 5))
array[3, ...] = -np.inf  # make row 3 entirely -inf (assignment avoids 0 * -inf = NaN)

# [[  0.   1.   2.   3.   4.]
#  [  5.   6.   7.   8.   9.]
#  [ 10.  11.  12.  13.  14.]
#  [-inf -inf -inf -inf -inf]
#  [ 20.  21.  22.  23.  24.]]

array = torch.tensor(array)
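array = F.softmax(array, dim=1)  # row 3 comes out as all NaN
mask = array.isnan().all(dim=1)  # True only for rows that are entirely NaN
array[mask, ...] = 0             # zero those rows; any other NaNs are left untouched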
print(array)

# tensor([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
#         [0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
#         [0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
#         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.0117, 0.0317, 0.0861, 0.2341, 0.6364]])
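If this runs inside a model that needs gradients, two things can bite: the in-place assignment `array[mask, ...] = 0` modifies the tensor that softmax saves for its backward pass, which can raise an autograd in-place error, and even an out-of-place fix applied after the softmax (e.g. `torch.where`) still leaves NaN in the saved softmax result, which can turn into NaN gradients for those rows. A minimal sketch of a reusable function, swapping in a different technique from the answer above: it neutralizes the all -inf rows before the softmax and zeroes them afterwards with out-of-place `masked_fill`. The name `masked_row_softmax` is hypothetical, not a PyTorch API.

```python
import torch
import torch.nn.functional as F


def masked_row_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax along `dim` that returns zeros for slices that are entirely -inf.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    # True where every entry of the slice along `dim` is -inf;
    # keepdim=True so the mask broadcasts back over `dim`.
    all_neg_inf = torch.isneginf(x).all(dim=dim, keepdim=True)
    # Replace those slices with zeros *before* the softmax, so no NaN is
    # produced in the forward pass (and none can reach the gradients).
    safe_x = x.masked_fill(all_neg_inf, 0.0)
    out = F.softmax(safe_x, dim=dim)
    # Out-of-place masked_fill: zero the slices that were all -inf,
    # without modifying the tensor softmax saved for its backward pass.
    return out.masked_fill(all_neg_inf, 0.0)


# Same input as the example above, but tracking gradients.
x = torch.arange(25, dtype=torch.float32).reshape(5, 5)
x[3] = float("-inf")
x.requires_grad_(True)

y = masked_row_softmax(x, dim=1)
print(y)  # row 3 is all zeros; the other rows match F.softmax

y[:, -1].sum().backward()
print(torch.isnan(x.grad).any())  # tensor(False)
```

The mask uses `keepdim=True`, so the same function works for any `dim`, not just rows of a 2-D tensor.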