Tensorflow:如何获得正确的预测?

时间:2017-06-20 05:59:02

标签: python tensorflow

我加载之前已经保存过的模型,并输入图片来预测课程,但无论我输入什么图片,我仍然得到相同的预测,我不知道为什么以及如何解决。模型测试还可以。这是我的代码:

# -*- coding: utf-8 -*-

from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import re
import os
import cg


checkpoint='/home/vrview/tensorflow/example/char/tfrecords/try1/cg_try/'
img_dir='/home/vrview/tensorflow/example/char/test_abc/5093.jpg'
MODEL_SAVE_PATH = "/home/vrview/tensorflow/example/char/tfrecords/try1/cg_try/"

def get_one_image():
    image = Image.open(img_dir)    
    image = image.resize([56, 56])
    image = np.array(image)
    return image

def evaluate():
    image_array = get_one_image()
    with tf.Graph().as_default():
        image = tf.cast(image_array, tf.float32)
        image_1 = tf.image.per_image_standardization(image) 
        image_2 = tf.reshape(image_1, [1, 56, 56, 3])

        logit = cg.inference(image_2, evaluate, None)
        y = tf.nn.softmax(logit)    
        x = tf.placeholder(tf.float32, shape=[56, 56, 3])
        saver = tf.train.Saver()

        with tf.Session() as sess:
          ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
          if ckpt and ckpt.model_checkpoint_path:
               global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
               saver.restore(sess, ckpt.model_checkpoint_path)
               print('Loading success, global_step is %s' % global_step)
               prediction = sess.run(y, feed_dict={x: image_array})
               max_index = np.argmax(prediction)
               print ('max_index=%d'%max_index)
          else:
               print('No checkpoint file found')

def main(argv=None):
    evaluate()

if __name__ == '__main__':
    tf.app.run()

当我运行调试时,我得到的预测是[0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]。我的代码是10分类,输入图片是10个数字,大小是[56,56,3]。无论输入是什么,我都得到max_index为1。有人知道吗?非常感谢你!

1 个答案:

答案 0 :(得分:0)

问题是TensorFlow中的任何内容都没有使用x = tf.placeholder ...。您在那里声明了一个占位符,其他任何内容都没有使用x

回想一下,TensorFlow是一个计算图,当你调用sess.run时,它会执行必要的操作并返回值。我不了解您尝试使用cg.inference做什么,但考虑使用x作为某些TensorFlow操作的输入,并在TensorFlow中完成您需要做的大部分工作。