我想使用一个以张量为输入并返回跨张量轴的最大值的索引的函数。我知道有一个功能完全相同的tf.argmax(),但是我如何自己实现(在实现某些自定义功能的情况下可能需要这样做)?
现在让我们假设函数仅将一维张量作为输入。因此,该功能必须具有以下签名:
argmax(
input, #input is a 1D tensor
name=None
)
我尝试过这样实现:
def argmax(input, name=None):
maxValue=0
maxIndex=0
for i in range(input.get_shape()[0]):
if input[i]>maxValue:
maxValue=input[i]
maxIndex=i
return maxIndex
但是,这不起作用,因为在构建阶段,这些值尚未初始化,因此我无法像上面的代码中那样比较两个值。那么,有没有办法写出诸如tf.argmax,tf.equal等自定义函数?
答案 0 :(得分:0)
嗯,一种简单的方法是这样:
idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]
示例:
import tensorflow as tf
with tf.Session() as sess:
input = tf.constant([1, 3, 4, 2, 1, 2])
idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]
print(sess.run(idx))
输出:
2