我希望“拉伸” pytorch张量的最后两个维度,以提高(批,通道,y,x)张量的空间分辨率。
最小示例(我需要“ new_function”)
a = torch.tensor([[1, 2], [3, 4]])
b = new_function(a, (2, 3))
print(b)
tensor([[1, 1, 1, 2, 2, 2],
[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4],
[3, 3, 3, 4, 4, 4]])
一种解决方法(针对实际问题):
a = torch.ones((2, 256, 2, 2)) # my original data.
b = torch.zeros((2, 256, 80, 96)) # The output I need
b[:, :, :40, :48] = a[:, :, 0, 0]
b[:, :, 40:, :48] = a[:, :, 1, 0]
b[:, :, :40, 48:] = a[:, :, 0, 1]
b[:, :, 40:, 48:] = a[:, :, 1, 1]
答案 0 :(得分:0)
使用torch.nn.functional.interpolate
(感谢Shai)
torch.nn.functional.interpolate(input_tensor.float(), size=(4, 6))
我最初的想法是使用各种view
和repeat
方法:
def stretch(e, sdims):
od = e.shape
return e.view(od[0], od[1], -1, 1).repeat(1, 1, 1, sdims[-1]).view(od[0], od[1], od[2], -1).repeat(1, 1, 1, sdims[-2]).view(od[0], od[1], od[2] * sdims[0], od[3] * sdims[1])
torch.Size([2, 2, 4, 6])