我有三个问题:
假设我们有一个张量x = tf.constant([[0, 2, 1], [0, 0, 8], [2, 9, 0]])
,所需的输出将是max = 9, index = [2,1]
。
我已尝试使用tf.argmax
函数,但axis
的参数tf.argmax
不接受tuple
类型,例如:{axis = (0, 1)
{1}},仅在axis = 0
或axis = 1
时。
我已提到this问题,但他们没有回答我上面的问题。
另外,如何找到2D张量的最大值(比如前5个最大值),因为tf.reduce_max
只返回1?
在获得索引之后,如何使用这些索引来索引具有相同大小y
的另一个张量x
的值?
更新1 :正如Aldream指出,我的前两个问题是此问题的重复How to find the top k values in a 2-D tensor in tensorflow。不过,第三个问题仍然存在。
更新2 :对于第三个问题,如果有人遇到同样的问题,我们可以使用tf.gather
。