tf.argmax()返回意外结果

时间:2018-04-19 14:31:44

标签: python tensorflow mnist argmax

我最近正在制作一个基于张量流CNN,带有服务器接口的MNIST数据集的项目。

在预测部分,我使用 tf.argmax()来获取最大的logit,这将是预测值。但是,它返回的值似乎不是正确的答案。

预测功能大致如下:

    self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
    self._create_model()

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state('../checkpoints/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

    pred = tf.nn.softmax(self.logits)
    prediction = tf.argmax(pred, 1)
    logit = sess.run(pred)
    result = sess.run(prediction)[0]
    print(logit)
    print(result)

    return result

结果是:

127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]]
1

如您所见,logits显示最大数字的索引 5 ,但 tf.argmax()给了我 1 而不是。

顺便说一句,my model是基本的MNIST CNN模型,您可以在链接中看到。

那么 tf.argmax()函数发生了什么,或者我的代码出了什么问题?

1 个答案:

答案 0 :(得分:1)

由于您的logitpred)和resultprediction[0])来自两个不同的sess.run,我想知道是否有一些pred运行之间的差异。例如,图中有一个迭代器,它向模型发送输入。通过不同的运行,迭代器发送不同的数据,从而导致不同的预测。如果您将predictionsess.run放在同一个logit, result = sess.run((pred, prediction)) print(logit) print(result[0]) 中,将会很有趣:

$ bazel-bin/lm_1b/lm_1b_eval --mode sample --prefix "I love that I"  --pbtxt data/vocab-2016-09-10.txt --vocab_file data/vocab-2016-09-10.txt --ckpt 'data/ckpt-*'