Pytorch张量索引:如何通过包含索引的张量收集行

时间:2019-04-27 13:22:40

标签: python indexing pytorch

我有张量:

id :形状(7000,1),其中包含[[1],[0],[2],...]

x :形状(7000, 3 ,255)

张量id编码应选择的以粗体标记的x索引。 我想将选定的切片收集到结果向量中:

结果:形状(7000,255)

背景

我对这3个元素中的每一个都有一些得分(形状=(7000,3)),并且只想选择得分最高的那个。因此我使用了该功能

ids = torch.argmax(scores,1,True)

为我提供了最大的ID。我已经尝试使用collect函数来做到这一点:

result = x.gather(1,ids)

但这没用。

3 个答案:

答案 0 :(得分:1)

这是您可能需要的解决方案

longPress

以下示例:

ids = ids.repeat(1, 255).view(-1, 1, 255)

答案 1 :(得分:0)

使用David Ng的示例,我找到了另一种方法:

idx = ids.flatten() + torch.arange(0,4*3,3)

tensor([ 0,  5,  6, 11])



x.view(-1,2)[idx]

tensor([[ 0,  1],
        [10, 11],
        [12, 13],
        [22, 23]])

答案 2 :(得分:0)

在维度更高的情况下,另一种解决方案可能会提供更好的内存读取模式。

# data
x = torch.arange(60).reshape(3, 4, 5)
# index
y = torch.randint(0, 4, (12,), dtype=torch.int64).reshape(3, 4)
# result
z = x[torch.arange(x.shape[0]).repeat_interleave(x.shape[1]), y.flatten()]
z = z.reshape(x.shape)

x、y、z 的示例结果将是

Tensor([[[ 0,  1,  2,  3,  4],
     [ 5,  6,  7,  8,  9],
     [10, 11, 12, 13, 14],
     [15, 16, 17, 18, 19]],

    [[20, 21, 22, 23, 24],
     [25, 26, 27, 28, 29],
     [30, 31, 32, 33, 34],
     [35, 36, 37, 38, 39]],

    [[40, 41, 42, 43, 44],
     [45, 46, 47, 48, 49],
     [50, 51, 52, 53, 54],
     [55, 56, 57, 58, 59]]])
tensor([[1, 1, 2, 3],
    [3, 1, 1, 0],
    [1, 1, 1, 1]])
tensor([[[ 5,  6,  7,  8,  9],
     [ 5,  6,  7,  8,  9],
     [10, 11, 12, 13, 14],
     [15, 16, 17, 18, 19]],

    [[35, 36, 37, 38, 39],
     [25, 26, 27, 28, 29],
     [25, 26, 27, 28, 29],
     [20, 21, 22, 23, 24]],

    [[45, 46, 47, 48, 49],
     [45, 46, 47, 48, 49],
     [45, 46, 47, 48, 49],
     [45, 46, 47, 48, 49]]])