我的变量length_X大小(10L,)和A大小(10L,16L,5L)。
我想使用lengths_X来沿着A的第二轴进行索引。换句话说,我想得到一个新的张量predict_Y大小(10L,5L),它将i轴上的轴索引为索引为i的所有条目0
在PyTorch中执行此操作的最佳方法是什么?
答案 0 :(得分:2)
您正在寻找的实际上是batched_index_select
,我之前寻找过这样的功能,但是在PyTorch中找不到可以完成任务的任何本机功能。但我们可以简单地使用:
A = torch.randn(10, 16, 5)
index = torch.from_numpy(numpy.random.randint(0, 16, size=10))
B = torch.stack([a[i] for a, i in zip(A, index)])
您可以看到讨论here。您还可以查看batched_index_select库中提供的功能AllenNLP。我很乐意知道是否有更好的解决方案。