如何将火炬张量旋转随机数度

时间:2020-08-27 15:33:29

标签: python rotation pytorch

作为训练CNN的一部分,我正在使用包含inputs对象的数组<class 'torch.Tensor'>。我想将单个<class 'torch.Tensor'>对象旋转某个随机数度数x,如下所示:

def rotate(inputs, x):
    # Rotate inputs[0] by x degrees, x can take on any value from 0 - 180 degrees

我该怎么做?对于现有的实现,我只能发现torch具有rot90函数,但这将我限制为90度的倍数,这对我的情况没有帮助。

谢谢Vinny

1 个答案:

答案 0 :(得分:0)

要转换torch.tensor,可以使用scipy.ndimage.rotate函数(读取here),该函数旋转torch.tensor,但也将其转换为numpy.ndarray,因此您必须将其转换回torch.tensor。看到这个玩具示例。

功能

def rotate(inputs, x):
    return torch.from_numpy(ndimage.rotate(inputs, x, reshape=False))

详细说明:

import torch
from scipy import ndimage
alpha = torch.rand(3,3)
print(alpha.dtype)#torch.float32

angle_in_degrees = 45
output = ndimage.rotate(alpha, angle_in_degrees, reshape=False)

print(output.dtype) #numpy_array

output = torch.from_numpy(output) #convert it back to torch tensor

print(output.dtype)  #torch.float32

此外,如果有可能,您可以在将PIL图像转换为张量之前直接对其进行转换。要转换PIL图像,您可以使用内置的PyTorch torchvision.transforms.functional.rotate(读为here)。