在Tensorflow 2.0中用另一个张量索引张量的第k个维度

时间:2020-06-04 09:41:03

标签: python numpy tensorflow indexing tensorflow2.0

我有一个张量probs,它的形状(None, None, 110)代表LSTM中的(batch_size, sequence_length, 110)。 我有另一个张量indices,其形状为(None, None),其中包含要从probs的三维标注中选择的元素的索引。

我想使用indices来索引张量probs

相当于脾气暴躁的人

k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]

由于未知shape[0]中的shape[1]probs,因此无法选择tf.meshgrid()。 我找到了tf.gathertf.gather_ndtf.batch_gather,但它们似乎都没有按照我的意愿做。

有人知道该怎么做吗?

1 个答案:

答案 0 :(得分:2)

您可以使用tf.gather_nd来做到这一点:

indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)

顺便说一句,在NumPy中,您可以使用np.take_along_axis进行相同的操作:

indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]