How to compute batched covariance of three tensors in PyTorch?

Assume we have three tensors of size (B, C, H, W), where B is the batch size and C is the channel dimension. I want to compute the covariance of these three tensors along the channel dimension.

I have tried the following code:

x1_mean = x1.mean(dim=1).unsqueeze(dim=1)
x2_mean = x2.mean(dim=1).unsqueeze(dim=1)
x3_mean = x3.mean(dim=1).unsqueeze(dim=1)
out = torch.matmul(torch.matmul(x1 - x1_mean, x2 - x2_mean), x3 - x3_mean)

Just wondering if my code makes sense. And is there another way to compute the covariance? Any help would be greatly appreciated.

Answer by inverted_index:

Your approach is a good start, but there seems to be a misunderstanding about how covariance is calculated, especially for more than two tensors. Computing the mean along the channel dimension and unsqueezing it (equivalently, x.mean(dim=1, keepdim=True)) does mean-center each tensor along that dimension. The torch.matmul calls, however, do not compute a covariance: on 4-D inputs, torch.matmul performs a batched matrix product over the last two dimensions, so here it multiplies (H, W) matrices, which only works when H == W, and it never sums over the channels. Covariance is also defined pairwise, between two sets of variables, not as a product of three.
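To make the pairwise idea concrete, here is a minimal sketch of a batched cross-covariance between two of the tensors. It assumes the channels are the variables and the H*W spatial positions are the observations, which may or may not be what you intend; the helper name batched_cross_cov and the example shapes are made up for illustration.

```python
import torch


def batched_cross_cov(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Cross-covariance between the channels of a and b, per batch element.

    Assumption: a and b have shape (B, C, H, W); channels are the variables
    and the H*W spatial positions are the observations.
    Returns a tensor of shape (B, C, C).
    """
    B, C, H, W = a.shape
    a = a.reshape(B, C, H * W)
    b = b.reshape(B, C, H * W)
    # Center each channel over its spatial positions
    a = a - a.mean(dim=2, keepdim=True)
    b = b - b.mean(dim=2, keepdim=True)
    # (B, C, HW) @ (B, HW, C) -> (B, C, C), unbiased estimate
    return torch.matmul(a, b.transpose(1, 2)) / (H * W - 1)


torch.manual_seed(0)
x1 = torch.randn(2, 3, 4, 5)  # illustrative shapes: B=2, C=3, H=4, W=5
x2 = torch.randn(2, 3, 4, 5)
cov12 = batched_cross_cov(x1, x2)
print(cov12.shape)  # torch.Size([2, 3, 3])
```

Entry [b, i, j] of the result is the covariance between channel i of x1 and channel j of x2 for batch element b. With three tensors you would get three such pairwise blocks (x1/x2, x1/x3, x2/x3) rather than a single triple product.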

That said, assuming you want a joint (multivariate) covariance matrix over the channels of all three tensors, the following code does the job:

# step 1: reshape tensors
B, C, H, W = x1.shape
x_combined = torch.cat([x1, x2, x3], dim=1)  # Shape: (B, C*3, H, W)
x_combined = x_combined.reshape(B, C*3, H*W) # Shape: (B, C*3, H*W)

# step 2: mean centering
mean_centered = x_combined - x_combined.mean(dim=2, keepdim=True)

# step 3: covariance matrix calculation
cov_matrix = torch.matmul(mean_centered, mean_centered.transpose(1, 2)) / (H*W - 1)

This snippet yields a batch of covariance matrices, each of size (3C, 3C), holding the covariance between every pair of channels across all three tensors. Here the channels are the variables and the H*W spatial locations are the observations, so the covariance is computed over the spatial locations, not at each one. Each matrix is the multivariate covariance matrix for one example in the batch.
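If you want to sanity-check the result, one option is to compare it against torch.cov applied to each batch element separately. The sketch below assumes a PyTorch version that provides torch.cov, and uses random tensors with made-up shapes as stand-ins for x1, x2 and x3.

```python
import torch

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 5  # illustrative shapes only
x1, x2, x3 = (torch.randn(B, C, H, W) for _ in range(3))

# Same computation as in the snippet above
x = torch.cat([x1, x2, x3], dim=1).reshape(B, 3 * C, H * W)
xc = x - x.mean(dim=2, keepdim=True)
cov_matrix = torch.matmul(xc, xc.transpose(1, 2)) / (H * W - 1)  # (B, 3C, 3C)

# Reference: torch.cov treats rows as variables and columns as observations,
# and is unbiased by default, so it is applied to one batch element at a time
reference = torch.stack([torch.cov(x[i]) for i in range(B)])
matches = torch.allclose(cov_matrix, reference, atol=1e-5)
print(matches)

# The off-diagonal blocks are the pairwise cross-covariances,
# e.g. channels of x1 against channels of x2:
cov_x1_x2 = cov_matrix[:, :C, C:2 * C]  # (B, C, C)
```

The diagonal blocks of cov_matrix are the per-tensor channel covariances, and the off-diagonal blocks are the cross-covariances between pairs of tensors, so the single (3C, 3C) matrix contains everything a pairwise analysis would give you.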