我正在尝试找出如何对张量为b (batch size), d (depth), h (hight) and w (width)
的张量进行移位:
b, d, h, w = tensor.size()
因此,我需要找到已移位张量与张量本身之间的减法。
我正在考虑使用torch.narrow
或torch.concat
来进行每一侧的操作(先左右,先左后右移动),然后每次从同一张量侧减去(张量本身的一面),然后最后我将对每一边的差/减相加/求和(这样,我将在平移张量本身之间得到最后的减法。
我是PyTorch的新手,它很容易理解,但难以实现,并且也许有一个更简单的方法(直接进行减法,而不是在各个方面进行工作等等。)
请问有什么帮助吗?
答案 0 :(得分:0)
基本上,您可以先分割张量,然后以相反的顺序排列它们。我编写了一个函数来实现您的想法。 shift
应该是一个非负数,并且小于或等于dim
的大小。
def tensor_shift(t, dim, shift):
"""
t (tensor): tensor to be shifted.
dim (int): the dimension apply shift.
shift (int): shift distance.
"""
assert 0 <= shift <= t.size(dim), "shift distance should be smaller than or equal to the dim length."
overflow = t.index_select(dim, torch.arange(t.size(dim)-shift, t.size(dim)))
remain = t.index_select(dim, torch.arange(t.size(dim)-shift))
return torch.cat((overflow, remain),dim=dim)
以下是一些测试结果。
a = torch.arange(1,13).view(-1,3)
a
#tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])
shift(a, 0, 1) # shift 1 unit along dim=0
#tensor([[10, 11, 12],
# [ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9]])
b = torch.arange(1,13).view(-1,2,3)
b
#tensor([[[ 1, 2, 3],
# [ 4, 5, 6]],
#
# [[ 7, 8, 9],
# [10, 11, 12]]])
shift(b, 1, 1) # shift 1 unit along dim=1
#tensor([[[ 4, 5, 6],
# [ 1, 2, 3]],
#
# [[10, 11, 12],
# [ 7, 8, 9]]])