我有一个张量probs
,它的形状(None, None, 110)
代表LSTM中的(batch_size, sequence_length, 110)
。
我有另一个张量indices
,其形状为(None, None)
,其中包含要从probs
的三维标注中选择的元素的索引。
我想使用indices
来索引张量probs
。
相当于脾气暴躁的人
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
由于未知shape[0]
中的shape[1]
和probs
,因此无法选择tf.meshgrid()
。
我找到了tf.gather
,tf.gather_nd
和tf.batch_gather
,但它们似乎都没有按照我的意愿做。
有人知道该怎么做吗?
答案 0 :(得分:2)
您可以使用tf.gather_nd
来做到这一点:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
顺便说一句,在NumPy中,您可以使用np.take_along_axis
进行相同的操作:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]