Tensorflow:使用argmax切片张量

时间:2018-05-29 07:20:44

标签: python tensorflow

我有一个形状为tf.shape(t1) = [1, 1000, 400]的张量,我使用形状为max_ind = tf.argmax(t1, axis=-1)的{​​{1}}获取第三维上的最大值的索引。现在我有第二个张量,其形状与[1, 1000]相同:t1

我想使用tf.shape(t2) = [1, 1000, 400]的最大值索引切片t1,因此输出的格式为

t2

更直观的描述:结果张量应该与[1, 1000] 的结果相似,但最大值的位置在tf.reduce_max(t2, axis=-1)

1 个答案:

答案 0 :(得分:2)

你可以通过tf.gather_nd来实现这一点,尽管它并不是那么简单。例如,

shape = t1.shape.as_list()
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1)
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
sliced_t2 = tf.gather_nd(t2, gather_ind)

另一方面,如果输入的形状未知为图形构建时间,则可以使用

shape = tf.shape(t1)
xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                              indexing='ij'), axis=-1)

,其余部分与上述相同。