Slice PyTorch tensors which are saved in a list


I have the following code segment to generate random samples. The generated samples form a list, where each entry of the list is a tensor with two elements. I would like to extract the first element from every tensor in the list, and likewise the second element from every tensor in the list. How can I perform this kind of tensor slicing?

import torch
import pyro
import pyro.distributions as dist

num_samples = 250
# note that the covariance matrix is diagonal
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])
dist1 = dist.MultivariateNormal(mu1, sig1)
# draw num_samples independent 2-element samples, one tensor per list entry
samples1 = [pyro.sample('samples1', dist1) for _ in range(num_samples)]

samples1



There are 3 answers below.

rkechols (best answer)

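I'd recommend torch.cat with a list comprehension. Slice each tensor with t[:1] (or t[1:2]) rather than indexing with t[0]: t[0] is a zero-dimensional tensor, and torch.cat cannot concatenate zero-dimensional tensors.

col1 = torch.cat([t[:1] for t in samples1])   # keep a length-1 dim so cat works
col2 = torch.cat([t[1:2] for t in samples1])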

Docs for torch.cat: https://pytorch.org/docs/stable/generated/torch.cat.html
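A quick check of why the element shape matters: indexing a 1-D tensor with t[0] yields a zero-dimensional tensor, which torch.cat rejects, while a length-1 slice or torch.stack both work. This is a minimal sketch on a small hypothetical three-tensor list (`samples` here is illustrative, not the question's data) so it runs with plain PyTorch and no Pyro.

```python
import torch

# Hypothetical stand-in for the question's list: each entry is a 2-element 1-D tensor.
samples = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]

# t[0] is zero-dimensional (shape torch.Size([])), so torch.cat refuses it.
print(samples[0][0].shape)  # torch.Size([])
try:
    torch.cat([t[0] for t in samples])
    cat_failed = False
except RuntimeError:
    cat_failed = True

# Keeping a length-1 dimension with t[:1] lets torch.cat join the pieces.
col1_cat = torch.cat([t[:1] for t in samples])

# torch.stack adds a new dimension, so it accepts the 0-d elements directly.
col1_stack = torch.stack([t[0] for t in samples])

print(col1_cat)    # tensor([1., 3., 5.])
print(col1_stack)  # tensor([1., 3., 5.])
```

Both routes produce the same 1-D column; the stack version simply avoids the slicing detail.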

ALTERNATIVELY

You could turn your list of 1D tensors into a single big 2D tensor using torch.stack, then do a normal slice:

samples1_t = torch.stack(samples1)
col1 = samples1_t[:, 0]  # : means all rows
col2 = samples1_t[:, 1]

Docs for torch.stack: https://pytorch.org/docs/stable/generated/torch.stack.html
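To make the shapes in the stack approach concrete, here is a sketch that stands in for samples1 with 250 hypothetical torch.randn(2) draws (swapped in for Pyro sampling so only PyTorch is needed). It shows that stacking gives a (250, 2) tensor and that each column slice lines up with the per-tensor elements.

```python
import torch

torch.manual_seed(0)
num_samples = 250

# Hypothetical stand-in for samples1: a list of 250 random 2-element tensors.
samples1 = [torch.randn(2) for _ in range(num_samples)]

# torch.stack joins the 1-D tensors along a new leading dimension.
samples1_t = torch.stack(samples1)
print(samples1_t.shape)  # torch.Size([250, 2])

# Column slices: one value per original tensor.
col1 = samples1_t[:, 0]
col2 = samples1_t[:, 1]
print(col1.shape, col2.shape)  # torch.Size([250]) torch.Size([250])
```

torch.stack copies the data once into a new tensor; the column slices are then views into that tensor rather than further copies.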

Ivan

It's worth mentioning that PyTorch tensors support unpacking out of the box: you can unpack the first axis into multiple variables directly. Here torch.stack outputs a tensor of shape (rows, cols), so we just need to transpose it to (cols, rows) and unpack:

>>> c1, c2 = torch.stack(samples1).T

So you get c1 and c2, each of shape (rows,):

>>> c1
tensor([0.6433, 0.4667, 0.6811, 0.2006, 0.6623, 0.7033])

>>> c2
tensor([0.2963, 0.2335, 0.6803, 0.1575, 0.9420, 0.6963])
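The reason for the .T is that iterating (and therefore unpacking) a tensor walks its first axis. A minimal sketch on a hypothetical 3x2 tensor standing in for torch.stack(samples1); torch.unbind with dim=1 is shown as an equivalent way to split out the columns without transposing.

```python
import torch

# Hypothetical 3x2 stand-in for torch.stack(samples1): 3 rows, 2 columns.
stacked = torch.tensor([[1.0, 2.0],
                        [3.0, 4.0],
                        [5.0, 6.0]])

# Iteration runs over the first axis, so this yields 3 rows, not 2 columns.
rows = list(stacked)
print(len(rows))  # 3

# Transposing first puts the columns on the first axis, so they unpack cleanly.
c1, c2 = stacked.T
print(c1)  # tensor([1., 3., 5.])
print(c2)  # tensor([2., 4., 6.])

# torch.unbind along dim=1 returns the same columns as a tuple.
u1, u2 = torch.unbind(stacked, dim=1)
```

Unpacking without the transpose would fail here, because there are three rows but only two target names.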
ayandas

Other answers that suggest .stack() or .cat() are perfectly fine from a PyTorch perspective.

However, since the question involves Pyro, I'd add the following.

Since you are drawing IID samples with

[pyro.sample('samples1', dist1) for _ in range(num_samples)]

a better way to do it in Pyro is

dist1 = dist.MultivariateNormal(mu1, sig1).expand([num_samples])

This tells Pyro that the distribution is batched, with a batch size of num_samples. Sampling from it produces all the samples in a single tensor:

>>> dist1.sample()
tensor([[-0.8712,  6.6087],
        [ 1.6076, -0.2939],
        [ 1.4526,  6.1777],
        ...
        [-0.0168,  7.5085],
        [-1.6382,  2.1878]])
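To see what "batched" means in terms of shapes, here is a sketch using torch.distributions.MultivariateNormal, the PyTorch class that Pyro's MultivariateNormal builds on (swapped in so the snippet needs only PyTorch). It prints batch_shape and event_shape before and after expand(), and compares against passing a sample_shape to the unbatched distribution.

```python
import torch
from torch.distributions import MultivariateNormal

num_samples = 250
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])

base = MultivariateNormal(mu1, sig1)
print(base.batch_shape, base.event_shape)  # torch.Size([]) torch.Size([2])

# expand() adds a batch dimension: num_samples independent copies of the distribution.
batched = base.expand([num_samples])
print(batched.batch_shape, batched.event_shape)  # torch.Size([250]) torch.Size([2])

# A single sample() call now draws all the vectors at once.
s = batched.sample()
print(s.shape)  # torch.Size([250, 2])

# Passing a sample_shape to the unbatched distribution gives the same tensor shape.
s2 = base.sample((num_samples,))
print(s2.shape)  # torch.Size([250, 2])
```

Both routes give a (250, 2) tensor; expand() records the 250 in the distribution's batch_shape, whereas sample_shape leaves the distribution itself unbatched.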

Now it's easy to solve your original question; just slice it:

samples = dist1.sample()
samples[:, 0]  # all first elements
samples[:, 1]  # all second elements
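As a rough end-to-end check, this sketch (again using torch.distributions in place of pyro.distributions, an assumption made so only PyTorch is needed) draws the batched samples, slices out the two columns, and compares their means with mu1 = [0, 5]. The 0.5 tolerance is an arbitrary choice for illustration, comfortably wider than the sampling noise expected with 250 draws.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
num_samples = 250
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])

# Batched distribution, as in the answer (PyTorch version).
dist1 = MultivariateNormal(mu1, sig1).expand([num_samples])
samples = dist1.sample()

first = samples[:, 0]   # all first elements, shape (250,)
second = samples[:, 1]  # all second elements, shape (250,)

# Rough sanity check: the column means should sit near mu1.
print(first.mean().item(), second.mean().item())
```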