I am quite new to PyTorch and am currently running into out-of-memory issues.
Task: I have two 2D tensors, A of shape [1000, 14] and B of shape [100000, 14].
I have to find the distance of each row of tensor A to all rows of tensor B. Using the calculated distances, I then compute the mean (over the rows of A) of either the minimum or the mean distance of each row of A to B.
Current solution: my code to calculate the minimum distance:
```python
import torch

dist = list()
for row_id in range(A.shape[0]):
    # Minimum distance of a row in A to the rows of B
    dist.append(torch.linalg.norm(A[row_id, :] - B, dim=1).min().item())
result = torch.FloatTensor(dist).mean()
```
And my code to calculate the mean distance:
```python
dist = list()
for row_id in range(A.shape[0]):
    # Mean distance of a row in A to the rows of B
    dist.append(torch.linalg.norm(A[row_id, :] - B, dim=1).mean().item())
result = torch.FloatTensor(dist).mean()
```
Issue: This gives me the result, but it is either very slow (when run on the CPU) or often runs out of memory when run on the GPU (I have a T4 GPU - 8GB).
Can you please recommend a better way to calculate the Euclidean distance that is faster and does not run out of memory?
Thanks!
Sure.
The idea is to use the identity $\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2\,x^\top y$ together with outer products. See the following code for the case of the minimum:
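A sketch of this approach for the minimum case, reconstructed from the explanation that follows. The random data and the reduced sizes ([100, 14] and [10000, 14] instead of the question's shapes) are assumptions for illustration, and the `clamp_min(0)` before the square root is an addition guarding against rounding errors that can make tiny squared distances slightly negative (which would give NaN).

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumption); the question uses [1000, 14] and [100000, 14]
A = torch.randn(100, 14)
B = torch.randn(10000, 14)

# Squared norms of the rows of A and of the rows of B
x = torch.linalg.norm(A, dim=1) ** 2  # shape [100]
y = torch.linalg.norm(B, dim=1) ** 2  # shape [10000]

# o1[i, j] = x[i] and o2[i, j] = y[j]; ones are created on A's device and dtype
o1 = torch.outer(x, torch.ones(B.shape[0], device=A.device, dtype=A.dtype))
o2 = torch.outer(torch.ones(A.shape[0], device=A.device, dtype=A.dtype), y)

# Squared distances: ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>
n = o1 + o2 - 2 * A @ B.t()

# Clamp tiny negative values caused by floating-point rounding before sqrt
dist = torch.sqrt(n.clamp_min(0))  # shape [100, 10000]

# Minimum distance of each row of A, averaged over the rows of A
result = dist.min(dim=1).values.mean()

# The loop from the question, for comparison
loop_dist = [torch.linalg.norm(A[row_id, :] - B, dim=1).min().item()
             for row_id in range(A.shape[0])]
result_loop = torch.FloatTensor(loop_dist).mean()
print(torch.allclose(result, result_loop, atol=1e-4))
```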
The last line compares the result of your implementation with mine.
Explanation:
`x = torch.linalg.norm(A, dim=1)**2` computes a vector whose elements are the squared norms of the rows of `A`. Similarly, `y = torch.linalg.norm(B, dim=1)**2` holds the squared norms of the rows of `B`.

`o1 = torch.outer(x, torch.ones(B.shape[0]))` is a matrix with as many identical columns as there are rows in `B`; each column is the vector of squared norms of the rows of `A`. Similarly, `o2 = torch.outer(torch.ones(A.shape[0]), y)` is a matrix with as many identical rows as there are rows in `A`, and each row holds the squared norms of the rows of `B`.

So the matrix `o1 + o2` is such that at indices `i, j` the value is the squared norm of the ith row of `A` plus the squared norm of the jth row of `B`. What remains is to subtract twice the inner product of the ith row of `A` with the jth row of `B`, which is done with `n = o1 + o2 - 2 * A @ B.t()`.

Now `sqrt(n)` holds the Euclidean distance between the ith row of `A` and the jth row of `B` at indices `i, j`. What's left is to take the minimum or the mean along each row; in this case, the minimum.
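With the question's full sizes, a single [1000, 100000] float32 matrix is about 400 MB (1000 × 100000 × 4 bytes), and the approach described above materializes several matrices of that size (`o1`, `o2`, `A @ B.t()`, `n`). A minimal sketch that bounds memory swaps in PyTorch's built-in `torch.cdist`, whose default `compute_mode` uses the same matrix-multiplication identity for larger inputs, and processes `A` in chunks of rows. It handles both the minimum and the mean case. The function name `min_and_mean_distance`, the `chunk_size` parameter, and the small random tensors are hypothetical choices for illustration.

```python
import torch


def min_and_mean_distance(A: torch.Tensor, B: torch.Tensor, chunk_size: int = 256):
    """Hypothetical helper: returns (mean of per-row minimum distance,
    mean of per-row mean distance) from the rows of A to the rows of B,
    processing A in chunks so only a [chunk_size, B.shape[0]] matrix exists at once."""
    row_min = []
    row_mean = []
    for start in range(0, A.shape[0], chunk_size):
        # Pairwise Euclidean distances for this chunk of A's rows
        d = torch.cdist(A[start:start + chunk_size], B)
        row_min.append(d.min(dim=1).values)
        row_mean.append(d.mean(dim=1))
    return torch.cat(row_min).mean(), torch.cat(row_mean).mean()


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Small illustrative sizes (assumption); the question's tensors are [1000, 14] and [100000, 14]
A = torch.randn(200, 14, device=device)
B = torch.randn(20000, 14, device=device)

min_result, mean_result = min_and_mean_distance(A, B, chunk_size=64)
print(min_result.item(), mean_result.item())
```

Peak memory is governed by `chunk_size × B.shape[0]`, so lowering `chunk_size` trades some speed for a smaller footprint on the GPU.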