我目前正在将某些代码从tensorflow转换为pytorch,我遇到了tf.gather
func的问题,没有直接函数将其在pytorch中转换。
我想要做的基本上是索引,我有两个张量,特征张量形状为[minibatch, 60, 2]
,索引张量为[minibatch, 8]
,就像第一个张量是张量A
一样,第二个是B
。
在Tensorflow中,它直接用tf.gather(A, B, batch_dims=1)
转换
我如何在pytorch中实现这一目标?
我尝试了A[B]
索引。这似乎行不通
和A[0]B[0]
有效,但形状的输出为[8, 2]
我需要[minibatch, 8, 2]
如果我像[stack, 8, 2]
那样堆叠张量,它可能会起作用,但我不知道该怎么做
tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great
[minibatch, 8, 2]
的输出形状
答案 0 :(得分:0)
我认为您正在寻找torch.gather
out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))