Convert FloatTensor to HalfTensor in PyTorch?
Asked by Kilaru Vasudeva
I am using a pretrained model from transformers that expects inputs of type torch.HalfTensor. However, my input is a torch.FloatTensor. How do I convert it?
To add to the answer by Learning is a mess: there are several ways to convert a tensor from float to half precision:
import torch
t_f = torch.FloatTensor(3, 2)
print(t_f.dtype) # torch.float32
# all t_h* tensors have dtype torch.float16
t_h1 = t_f.half()  # works for CPU and GPU tensors
t_h2 = t_f.type(torch.HalfTensor)  # CPU tensors only; use torch.cuda.HalfTensor for GPU tensors
t_h3 = t_f.to(torch.half)  # .to() also works on models
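As a quick sanity check (a sketch; the tensor values are arbitrary), the three routes should give identical float16 tensors. Casting back with .float() does not recover the original float32 values exactly: float16 has a 10-bit mantissa (roughly 3 decimal digits), and anything above its maximum of 65504 overflows to inf.

```python
import torch

t_f = torch.tensor([0.1, 1.0, 65504.0, 70000.0])  # arbitrary float32 values

t_h1 = t_f.half()
t_h2 = t_f.type(torch.HalfTensor)
t_h3 = t_f.to(torch.half)

# All three conversions produce the same float16 tensor
print(torch.equal(t_h1, t_h2) and torch.equal(t_h1, t_h3))

# Converting back to float32 shows the precision loss:
# 0.1 becomes approximately 0.09998, and 70000 has overflowed to inf
print(t_h1.float())
```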
See the docs for more info on torch tensor data types.
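For the original use case, one option is to read the dtype and device from the model's own parameters and cast the input with .to(). This is a sketch and is not specific to transformers. The nn.Linear is a hypothetical stand-in for the pretrained model, and match_model_dtype is a hypothetical helper that assumes all parameters share one dtype and device. Integer inputs such as token IDs are left as integers, since embedding lookups need integer indices. The forward pass runs only when CUDA is available, because float16 matrix multiplication is not supported on CPU in every PyTorch version.

```python
import torch
from torch import nn


def match_model_dtype(x: torch.Tensor, model: nn.Module) -> torch.Tensor:
    """Hypothetical helper: move x to the model's device and, if x is
    floating point, cast it to the model's dtype.

    Assumes every parameter of `model` shares one dtype and device.
    """
    param = next(model.parameters())
    if x.is_floating_point():
        return x.to(device=param.device, dtype=param.dtype)
    # Integer inputs (e.g. token IDs) keep their dtype; only the device changes
    return x.to(device=param.device)


model = nn.Linear(2, 4).half()  # hypothetical stand-in for a float16 pretrained model
x = torch.rand(3, 2)            # float32 input
x_h = match_model_dtype(x, model)
print(x_h.dtype)

ids = torch.tensor([[1, 2, 3]])  # arbitrary int64 "token IDs"
print(match_model_dtype(ids, model).dtype)

if torch.cuda.is_available():
    model = model.cuda()
    out = model(match_model_dtype(x, model))
    print(out.dtype)
```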
torch.Tensor.half() should do: