在PyTorch中按列表在轴上编制索引

时间:2018-03-05 05:57:36

标签: indexing pytorch

我的变量length_X大小(10L,)和A大小(10L,16L,5L)。

我想使用lengths_X来沿着A的第二轴进行索引。换句话说,我想得到一个新的张量predict_Y大小(10L,5L),它将i轴上的轴索引为索引为i的所有条目0

在PyTorch中执行此操作的最佳方法是什么?

1 个答案:

答案 0 :(得分:2)

您正在寻找的实际上是batched_index_select,我之前寻找过这样的功能,但是在PyTorch中找不到可以完成任务的任何本机功能。但我们可以简单地使用:

A = torch.randn(10, 16, 5)
index = torch.from_numpy(numpy.random.randint(0, 16, size=10))
B = torch.stack([a[i] for a, i in zip(A, index)])

您可以看到讨论here。您还可以查看batched_index_select库中提供的功能AllenNLP。我很乐意知道是否有更好的解决方案。