我有一个形状为tf.shape(t1) = [1, 1000, 400]
的张量,我使用形状为max_ind = tf.argmax(t1, axis=-1)
的{{1}}获取第三维上的最大值的索引。现在我有第二个张量,其形状与[1, 1000]
相同:t1
。
我想使用tf.shape(t2) = [1, 1000, 400]
的最大值索引切片t1
,因此输出的格式为
t2
更直观的描述:结果张量应该与[1, 1000]
的结果相似,但最大值的位置在tf.reduce_max(t2, axis=-1)
答案 0 :(得分:2)
你可以通过tf.gather_nd
来实现这一点,尽管它并不是那么简单。例如,
shape = t1.shape.as_list()
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1)
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
sliced_t2 = tf.gather_nd(t2, gather_ind)
另一方面,如果输入的形状未知为图形构建时间,则可以使用
shape = tf.shape(t1)
xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
indexing='ij'), axis=-1)
,其余部分与上述相同。