我有张量:
id :形状(7000,1),其中包含[[1],[0],[2],...]
x :形状(7000, 3 ,255)
张量id编码应选择的以粗体标记的x索引。 我想将选定的切片收集到结果向量中:
结果:形状(7000,255)
背景:
我对这3个元素中的每一个都有一些得分(形状=(7000,3)),并且只想选择得分最高的那个。因此我使用了该功能
ids = torch.argmax(scores,1,True)
为我提供了最大的ID。我已经尝试使用collect函数来做到这一点:
result = x.gather(1,ids)
但这没用。
答案 0 :(得分:1)
这是您可能需要的解决方案
longPress
以下示例:
ids = ids.repeat(1, 255).view(-1, 1, 255)
答案 1 :(得分:0)
使用David Ng的示例,我找到了另一种方法:
idx = ids.flatten() + torch.arange(0,4*3,3)
tensor([ 0, 5, 6, 11])
x.view(-1,2)[idx]
tensor([[ 0, 1],
[10, 11],
[12, 13],
[22, 23]])
答案 2 :(得分:0)
在维度更高的情况下,另一种解决方案可能会提供更好的内存读取模式。
# data
x = torch.arange(60).reshape(3, 4, 5)
# index
y = torch.randint(0, 4, (12,), dtype=torch.int64).reshape(3, 4)
# result
z = x[torch.arange(x.shape[0]).repeat_interleave(x.shape[1]), y.flatten()]
z = z.reshape(x.shape)
x、y、z 的示例结果将是
Tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
tensor([[1, 1, 2, 3],
[3, 1, 1, 0],
[1, 1, 1, 1]])
tensor([[[ 5, 6, 7, 8, 9],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[35, 36, 37, 38, 39],
[25, 26, 27, 28, 29],
[25, 26, 27, 28, 29],
[20, 21, 22, 23, 24]],
[[45, 46, 47, 48, 49],
[45, 46, 47, 48, 49],
[45, 46, 47, 48, 49],
[45, 46, 47, 48, 49]]])