求移位张量的减法

时间:2019-09-25 16:49:13

标签: tensorflow pytorch tensor

我正在尝试找出如何对张量为b (batch size), d (depth), h (hight) and w (width)的张量进行移位:

b, d, h, w = tensor.size()

因此,我需要找到已移位张量与张量本身之间的减法。

我正在考虑使用torch.narrowtorch.concat来进行每一侧的操作(先左右,先左后右移动),然后每次从同一张量侧减去(张量本身的一面),然后最后我将对每一边的差/减相加/求和(这样,我将在平移张量本身之间得到最后的减法。

我是PyTorch的新手,它很容易理解,但难以实现,并且也许有一个更简单的方法(直接进行减法,而不是在各个方面进行工作等等。)

请问有什么帮助吗?

1 个答案:

答案 0 :(得分:0)

基本上,您可以先分割张量,然后以相反的顺序排列它们。我编写了一个函数来实现您的想法。 shift应该是一个非负数,并且小于或等于dim的大小。

def tensor_shift(t, dim, shift):
    """
    t (tensor): tensor to be shifted. 
    dim (int): the dimension apply shift.
    shift (int): shift distance.
    """
    assert 0 <= shift <= t.size(dim), "shift distance should be smaller than or equal to the dim length."

    overflow = t.index_select(dim, torch.arange(t.size(dim)-shift, t.size(dim)))
    remain = t.index_select(dim, torch.arange(t.size(dim)-shift))

    return torch.cat((overflow, remain),dim=dim)

以下是一些测试结果。

a = torch.arange(1,13).view(-1,3)
a
#tensor([[ 1,  2,  3],
#        [ 4,  5,  6],
#        [ 7,  8,  9],
#        [10, 11, 12]])

shift(a, 0, 1) # shift 1 unit along dim=0
#tensor([[10, 11, 12],
#        [ 1,  2,  3],
#        [ 4,  5,  6],
#        [ 7,  8,  9]])

b = torch.arange(1,13).view(-1,2,3)
b
#tensor([[[ 1,  2,  3],
#         [ 4,  5,  6]],
#
#        [[ 7,  8,  9],
#         [10, 11, 12]]])

shift(b, 1, 1) # shift 1 unit along dim=1
#tensor([[[ 4,  5,  6],
#         [ 1,  2,  3]],
#
#        [[10, 11, 12],
#         [ 7,  8,  9]]])