PyTorch custom forward function does not work with DataParallel


Edit: I have tried PyTorch 1.6.0 and 1.7.1; both give me the same error.

I have a model that lets users easily switch between two architectures, A and B. The two architectures also have different forward functions, so I have the following model class:

P.S. I am using a very simple example here to demonstrate my problem; the actual model is much more complicated.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, condition):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        
        if condition == 'A':
            self.forward = self.forward_A
        elif condition == 'B':
            self.linear2 = nn.Linear(10, 1)
            self.forward = self.forward_B
            
    def forward_A(self, x):
        return self.linear(x)
    
    def forward_B(self, x1, x2):
        return self.linear(x1) + self.linear2(x2)
    

It works well in the single-GPU case. In the multi-GPU case, however, it throws an error.

device = 'cuda:0'
x = torch.randn(8, 10).to(device)

model = Net('B')
model = model.to(device)
model = nn.DataParallel(model)

model(x, x)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)

How do I make this model class work with nn.DataParallel?


There are 2 answers below.

Answer by Shai:

You are forcing the input x and the model to be on the 'cuda:0' device, but when working with multiple GPUs you should not specify any particular device.
Try:

x = torch.randn(8, 10)
model = Net('B')
model = nn.DataParallel(model, device_ids=[0, 1]).cuda()  # assuming 2 GPUs
pred = model(x, x)
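
nn.DataParallel splits each input tensor along the batch dimension, runs a replica of the model on every GPU in device_ids, and gathers the outputs back on device_ids[0], so there is no need to move the inputs to a particular GPU yourself.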

Answer by Eric Hedlin:

This problem goes away if you define two wrapper modules, each with its own ordinary forward method that calls into this model, and wrap those in nn.DataParallel as usual.

The reason the original class fails is that self.forward = self.forward_B stores a bound method on the instance; when nn.DataParallel replicates the module, the replicas inherit that bound method, which still points at the original module whose parameters sit on cuda:0, while the inputs have already been scattered to the other GPUs.
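
A minimal sketch of the wrapper idea, assuming at least 2 GPUs (the Backbone/WrapperA/WrapperB names are my own for illustration, not from the question):

import torch
import torch.nn as nn

class Backbone(nn.Module):
    """Holds the shared layers; no forward swapping in __init__."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.linear2 = nn.Linear(10, 1)

class WrapperA(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):  # ordinary class-level forward, safe to replicate
        return self.backbone.linear(x)

class WrapperB(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x1, x2):  # ordinary class-level forward, safe to replicate
        return self.backbone.linear(x1) + self.backbone.linear2(x2)

backbone = Backbone()
model = nn.DataParallel(WrapperB(backbone)).cuda()
x = torch.randn(8, 10)
pred = model(x, x)  # each replica now calls its own forward with its own parameter copies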