pytorch张量索引

时间:2019-07-17 07:51:39

标签: python indexing pytorch

我目前正在将某些代码从tensorflow转换为pytorch,我遇到了tf.gather func的问题,没有直接函数将其在pytorch中转换。

我想要做的基本上是索引,我有两个张量,特征张量形状为[minibatch, 60, 2],索引张量为[minibatch, 8],就像第一个张量是张量A一样,第二个是B

在Tensorflow中,它直接用tf.gather(A, B, batch_dims=1)转换

我如何在pytorch中实现这一目标?

我尝试了A[B]索引。这似乎行不通

A[0]B[0]有效,但形状的输出为[8, 2]

我需要[minibatch, 8, 2]

的形状

如果我像[stack, 8, 2]那样堆叠张量,它可能会起作用,但我不知道该怎么做

tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great

[minibatch, 8, 2]的输出形状

1 个答案:

答案 0 :(得分:0)

我认为您正在寻找torch.gather

out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))