如何独自实现tf.argmax?

时间:2018-09-17 15:01:05

标签: python tensorflow machine-learning

我想使用一个以张量为输入并返回跨张量轴的最大值的索引的函数。我知道有一个功能完全相同的tf.argmax(),但是我如何自己实现(在实现某些自定义功能的情况下可能需要这样做)?

现在让我们假设函数仅将一维张量作为输入。因此,该功能必须具有以下签名:

argmax(
    input, #input is a 1D tensor
    name=None
)

我尝试过这样实现:

def argmax(input, name=None):
    maxValue=0
    maxIndex=0
    for i in range(input.get_shape()[0]):
        if input[i]>maxValue:
            maxValue=input[i]
            maxIndex=i
    return maxIndex

但是,这不起作用,因为在构建阶段,这些值尚未初始化,因此我无法像上面的代码中那样比较两个值。那么,有没有办法写出诸如tf.argmax,tf.equal等自定义函数?

1 个答案:

答案 0 :(得分:0)

嗯,一种简单的方法是这样:

idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]

示例:

import tensorflow as tf

with tf.Session() as sess:
    input = tf.constant([1, 3, 4, 2, 1, 2])
    idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]
    print(sess.run(idx))

输出:

2