pytorch all_gather gives wrong output order


I was using the PyTorch module torch.nn.parallel.DistributedDataParallel to run batches on multiple GPUs. However, I found that the output after torch.distributed.all_gather(gather_outputs, out) comes back in the wrong order. For example, the data order is 0, 1, 2, 3, but the gathered output order is 0, 1, 3, 2.

Here is my code:

model = torch.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False)
outputs = []
for inp, label in test_loader:
    inp = inp.cuda()
    label = label.cuda()
    out = model(inp)
    # one buffer per rank, filled in by all_gather
    gather_outputs = [torch.zeros_like(out) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(gather_outputs, out)  # results are not in the expected order
    outputs.extend(gather_outputs)
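
For comparison, this is roughly how I expected all_gather to order results: the gathered list should be indexed by rank, so gather_outputs[i] holds the tensor contributed by rank i. This is a minimal standalone sketch (the gloo backend, world size of 2, and the toy tensor values are just for illustration, not my real setup):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank contributes a tensor holding its own rank id.
    out = torch.tensor([float(rank)])
    gather_outputs = [torch.zeros_like(out) for _ in range(world_size)]
    dist.all_gather(gather_outputs, out)

    # Expected on every rank: [tensor([0.]), tensor([1.])], i.e. ordered by rank.
    print(rank, gather_outputs)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)

In my real run the per-rank order of the data itself seems to be what differs, so the concatenated outputs no longer match the original dataset order.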