RuntimeError: expected scalar type Double but found Float in PyTorch code

import nibabel as nib
import numpy as np
import torch

def encoder_block(inp, max_pool, in_channels):
    # the RuntimeError is raised by this first Conv2d call
    conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding='same')(inp.double())
    relu = torch.nn.ReLU()(conv)
    conv = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same')(relu)
    relu = torch.nn.ReLU()(conv)
    if max_pool:
        return torch.nn.MaxPool2d(2, 2)(relu)
    return relu

test_load = nib.load(fpath).get_fdata()
test_numpy = test_load[:, :, 0].reshape(1, 1, 256, 256).astype(np.double)
tens = torch.DoubleTensor(test_numpy)
out = encoder_block(tens, True, 1)

This code should load a NIfTI file from local storage, convert it to a NumPy array, and then run a few convolutions on a 2D slice as a basic test for now.

The error is raised by the first Conv2d call: RuntimeError: expected scalar type Double but found Float. I'm not sure what else I can do to convert my data to float.
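
A minimal reproduction of the mismatch, using a random tensor as a stand-in for the NIfTI slice: a freshly built Conv2d holds float32 parameters, while torch.DoubleTensor and .double() produce float64 data. The exact wording of the error message differs between PyTorch versions, so the sketch only relies on it being a RuntimeError.

```python
import torch

# A freshly constructed Conv2d stores its weight and bias as float32,
# PyTorch's default floating-point dtype.
conv = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding='same')
print(conv.weight.dtype)  # torch.float32

# A float64 input (what torch.DoubleTensor or .double() produces) does not match.
x = torch.rand(1, 1, 256, 256, dtype=torch.float64)
print(x.dtype)  # torch.float64

try:
    conv(x)
except RuntimeError as err:
    # The message text depends on the installed PyTorch version.
    print("RuntimeError:", err)
```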

Answer by pythonic833

This is a dtype mismatch. PyTorch creates layer parameters (weights and biases) as torch.float32 by default, but you explicitly create an input tensor of type torch.float64. The convolutional layers therefore need to be told the matching dtype through the dtype keyword argument. A working version of the code:

import nibabel as nib
import numpy as np
import torch

def encoder_block(inp, max_pool, in_channels):
    # dtype=torch.float64 creates the weight and bias as float64 to match the input
    conv = torch.nn.Conv2d(in_channels=in_channels,
                           out_channels=64,
                           kernel_size=3,
                           padding='same',
                           dtype=torch.float64)(inp)
    relu = torch.nn.ReLU()(conv)
    conv = torch.nn.Conv2d(in_channels=64,
                           out_channels=64,
                           kernel_size=3,
                           padding='same',
                           dtype=torch.float64)(relu)
    relu = torch.nn.ReLU()(conv)
    if max_pool:
        return torch.nn.MaxPool2d(2, 2)(relu)
    return relu

test_load = nib.load(fpath).get_fdata()  # get_fdata() returns float64 by default
# tested with the next line
# test_load = np.random.rand(256, 256, 1)
test_numpy = test_load[:, :, 0].reshape(1, 1, 256, 256)
tens = torch.DoubleTensor(test_numpy)
out = encoder_block(tens, True, 1)
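
Passing dtype to every layer works. A sketch of an alternative, assuming you are free to restructure the code: define the layers once in a torch.nn.Module (the class name EncoderBlock is hypothetical) and convert the whole module with .double(), which casts all floating-point parameters and buffers to float64. Building the layers in __init__ also means the weights are not randomly re-initialized on every call, as they are when the Conv2d objects are created inside the function. A random float64 tensor stands in for the NIfTI slice.

```python
import torch


class EncoderBlock(torch.nn.Module):
    """Hypothetical module with the same layers as encoder_block, created once."""

    def __init__(self, in_channels, max_pool):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=3, padding='same')
        self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2, 2) if max_pool else torch.nn.Identity()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.pool(x)


# .double() casts every floating-point parameter and buffer to float64,
# so the layers now match a float64 input.
block = EncoderBlock(in_channels=1, max_pool=True).double()

tens = torch.rand(1, 1, 256, 256, dtype=torch.float64)  # stand-in for the NIfTI slice
out = block(tens)
print(out.dtype, tuple(out.shape))  # torch.float64 (1, 64, 128, 128)
```

The same pattern works in the other direction: keep the module at its default float32 and convert the input with tens.float() instead.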

Answer by Chris Markiewicz

I'm not sure what else I can do to convert my data to float.

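nibabel's get_fdata() accepts a dtype parameter, so you can load the image data as float32 directly, which matches the default dtype of PyTorch layers: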

test_load = nib.load(fpath).get_fdata(dtype=np.float32)
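
Loading as float32 only helps if nothing casts the data back to float64 afterwards: the question's code also calls .astype(np.double), torch.DoubleTensor and inp.double(), each of which reintroduces the mismatch. A minimal float32 sketch, using a random array as a stand-in for the NIfTI volume (the same trick as the commented-out line in the first answer), so nibabel and fpath are not needed to run it:

```python
import numpy as np
import torch


def encoder_block(inp, max_pool, in_channels):
    # No .double() here: the input stays float32, matching the default layer dtype.
    conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding='same')(inp)
    relu = torch.nn.ReLU()(conv)
    conv = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same')(relu)
    relu = torch.nn.ReLU()(conv)
    if max_pool:
        return torch.nn.MaxPool2d(2, 2)(relu)
    return relu


# Stand-in for nib.load(fpath).get_fdata(dtype=np.float32)
test_load = np.random.rand(256, 256, 1).astype(np.float32)

# Keep float32: no .astype(np.double) and no torch.DoubleTensor
test_numpy = test_load[:, :, 0].reshape(1, 1, 256, 256)
tens = torch.from_numpy(test_numpy)  # shares memory with test_numpy and keeps its float32 dtype
out = encoder_block(tens, True, 1)
print(out.dtype, tuple(out.shape))  # torch.float32 (1, 64, 128, 128)
```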