TensorFlow:关于tf.argmax()和tf.equal()的问题

时间:2017-01-17 23:00:48

标签: tensorflow neural-network deep-learning

我正在学习TensorFlow,构建一个多层感知器模型。我正在研究一些例子:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/multilayer_perceptron.ipynb

然后我在下面的代码中提出了一些问题:

tf.argmax(prod,1)

我想知道tf.argmax(y,1)correct_prediction是什么意思并且确切地返回(类型和值)?并且y_test_prediction是变量而不是实数值吗?

最后,我们如何从tf会话中获取X_test数组(输入数据为{{1}}时的预测结果)?非常感谢!

3 个答案:

答案 0 :(得分:39)

tf.argmax(input, axis=None, name=None, dimension=None)

返回张量轴上具有最大值的索引。

输入是Tensor,轴描述输入Tensor的哪个轴要减少。对于矢量,使用axis = 0。

对于您的具体情况,让我们使用两个数组并演示此

pred = np.array([[31, 23,  4, 24, 27, 34],
                [18,  3, 25,  0,  6, 35],
                [28, 14, 33, 22, 20,  8],
                [13, 30, 21, 19,  7,  9],
                [16,  1, 26, 32,  2, 29],
                [17, 12,  5, 11, 10, 15]])

y = np.array([[31, 23,  4, 24, 27, 34],
                [18,  3, 25,  0,  6, 35],
                [28, 14, 33, 22, 20,  8],
                [13, 30, 21, 19,  7,  9],
                [16,  1, 26, 32,  2, 29],
                [17, 12,  5, 11, 10, 15]])

评估tf.argmax(pred, 1)会给出一个张量,其评估将给出array([5, 5, 2, 1, 3, 0])

评估tf.argmax(y, 1)会给出一个张量,其评估将给出array([5, 5, 2, 1, 3, 0])

tf.equal(x, y, name=None) takes two tensors(x and y) as inputs and returns the truth value of (x == y) element-wise. 

按照我们的示例,tf.equal(tf.argmax(pred, 1),tf.argmax(y, 1))会返回一个张量,其评估将给出array(1,1,1,1,1,1)

correct_prediction是一个张量,其评估将给出一个0和1的一维数组

y_test_prediction可以通过执行pred = tf.argmax(logits, 1)

获得

可以通过以下链接访问tf.argmax和tf.equal的文档。

tf.argmax()https://www.tensorflow.org/api_docs/python/math_ops/sequence_comparison_and_indexing#argmax

tf.equal()https://www.tensorflow.org/versions/master/api_docs/python/control_flow_ops/comparison_operators#equal

答案 1 :(得分:14)

阅读文档:

tf.argmax

  

返回张量轴上具有最大值的索引。

tf.equal

  

以元素方式返回(x == y)的真值。

tf.cast

  

将张量转换为新类型。

tf.reduce_mean

  

计算张量维度的元素平均值。

现在您可以轻松解释它的作用。您的y是单热编码的,因此它有一个1而其他所有都为零。您的pred表示课程的概率。因此argmax找到最佳预测和正确值的位置。之后,检查它们是否相同。

所以现在你的correct_prediction是一个True / False值的向量,其大小等于你想要预测的实例数。您将其转换为浮点数并取平均值。

实际上这部分在评估模型部分的TF tutorial中得到了很好的解释

答案 2 :(得分:1)

tf.argmax(输入,轴=无,名称=无,维度=无)

  

返回张量轴上具有最大值的索引。

对于特定情况,它会将pred作为input1 axis作为[2.11,1.0021,3.99,4.32]的参数。轴描述输入张量的哪个轴要减少。对于矢量,使用axis = 0。

示例:给定列表3 argmax将返回accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")),这是最高值的索引。

correct_prediction 是一个将在稍后评估的张量。它不是常规的python变量。它包含以后计算值的必要信息。 对于此特定情况,它将成为另一个张量eval的一部分,并将由accuracy.eval({x: X_test, y: y_test_onehot})上的correct_prediction进行评估。

y_test_prediction 应该是您的<%= form_tag update_multiple_homeworks_path(:revision_active => true), method: :put do %> .... <td><%= check_box_tag "homework[id][]", homework.id, checked = false %></td> .... <%= submit_tag "Add to Bucket", :remote => true, :class => 'smallMobileRight button text-left' %> 张量。