如何在TensorFlow中编写argmax函数?

时间:2018-09-09 20:07:20

标签: python tensorflow reinforcement-learning

我不是在谈论tf.argmax,而是在数学意义上谈论argmax,例如给定一组离散值和一个函数,找到使它最大化的值。我目前有这样的东西

input = tf.placeholder(name='input')
Qhat = do_stuff_to(input) # e.g. tf.add(input, 3)

现在,我想定义另一个TensorFlow节点max_Qhat,它将使用Tensors数组作为其参数。它将把这些张量中的每个张量馈送到Qhat并返回产生最大值的张量。我该怎么做呢? (注意,我不是要运行 Qhat,所以没有session.run。我只是想定义一个对其求值的函数。 )

我到目前为止的代码:

inputs = tf.placeholder(name='inputs')
max_Qhat = ???

1 个答案:

答案 0 :(得分:1)

创建一个输入张量,例如尺寸为[k, ...](这是我们“最大化”其上的k输入张量的“数组”),然后在第一个轴/维度上计算Qhat op输入张量(例如,使用tf.map_fn),以使其返回维度为[k]的函数值的张量,然后确定该返回张量的最大值。

import tensorflow as tf;

sess = tf.InteractiveSession();

# define inputs
inputs = tf.constant([
    [1, 2, 3, 4, 5],
    [4, 3, 6, 2, 1],
    [9, 9, 9, 9, 9],
    [0, 1, 3, 5, 2]
], shape = (4, 5));

# the function/op that we want to compute for each input tensor
def some_op(t):
    return tf.reduce_sum(t);

# compute the function values
q = tf.map_fn(some_op, inputs);

# determine the index of the input tensor that maximizes the function
index = sess.run(tf.argmax(q, axis = 0));
maximizer = sess.run(inputs[index]);