如何使用我自己的图像提供Cifar10训练模型并获得标签作为输出?

时间:2016-06-01 21:13:18

标签: python tensorflow

我正在尝试使用基于Cifar10 tutorial的训练模型,并希望提供 它带有外部图像32x32(jpg或png) 我的目标是能够将标签作为输出。 换句话说,我想为网络提供一个大小为32 x 32的单个jpeg图像,3个没有标签的通道作为输入,并且推理过程给我 tf.argmax(logits, 1)
基本上我希望能够在外部图像上使用经过训练的cifar10模型,并查看它将吐出的类。

我一直在尝试基于Cifar10教程做到这一点,不幸的是总是有问题。尤其是会话概念和批处理概念。

对Cifar10的任何帮助都将不胜感激。

到目前为止,这是编译问题的实现代码:

#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
import tensorflow as tf

import cifar10
import cifar10_input
import os
import faultnet_flags
from PIL import Image

FLAGS = tf.app.flags.FLAGS

def evaluate():

  filename_queue = tf.train.string_input_producer(['/home/tensor/.../inputImage.jpg'])

  reader = tf.WholeFileReader()
  key, value = reader.read(filename_queue)

  input_img = tf.image.decode_jpeg(value)

  init_op = tf.initialize_all_variables()

# Problem in here with Graph / session
  with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1): 
      image = input_img.eval()

    print(image.shape)
    Image.fromarray(np.asarray(image)).show()

# Problem in here is that I have only one image as input and have no label and would like to have
# it compatible with the Cifar10 network
    reshaped_image = tf.cast(image, tf.float32)
    height = FLAGS.resized_image_size
    width = FLAGS.resized_image_size
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, width, height)
    float_image = tf.image.per_image_whitening(resized_image)  # reshaped_image
    num_preprocess_threads = 1
    images = tf.train.batch(
      [float_image],
      batch_size=128,
      num_threads=num_preprocess_threads,
      capacity=128)
    coord.request_stop()
    coord.join(threads)

    logits = faultnet.inference(images)

    # Calculate predictions.
    #top_k_predict_op = tf.argmax(logits, 1)

    # print('Current image is: ')
    # print(top_k_predict_op[0])

    # this does not work since there is a problem with the session
    # and the Graph conflicting
    my_classification = sess.run(tf.argmax(logits, 1))

    print ('Predicted ', my_classification[0], " for your input image.")


def main(argv=None):
  evaluate()

if __name__ == '__main__':
  tf.app.run() '''

2 个答案:

答案 0 :(得分:4)

首先是一些基础知识:

  1. 首先定义图形:图像队列,图像预处理,convnet推理,top-k精度
  2. 然后你创建一个tf.Session()并在其中工作:启动队列运行器,并调用sess.run()
  3. 以下是您的代码应该是什么样的

    # 1. GRAPH CREATION 
    filename_queue = tf.train.string_input_producer(['/home/tensor/.../inputImage.jpg'])
    ...  # NO CREATION of a tf.Session here
    float_image = ...
    images = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size=1)
    logits = faultnet.inference(images)
    _, top_k_pred = tf.nn.top_k(logits, k=5)
    
    # 2. TENSORFLOW SESSION
    with tf.Session() as sess:
        sess.run(init_op)
    
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
    
        top_indices = sess.run([top_k_pred])
        print ("Predicted ", top_indices[0], " for your input image.")
    

    编辑:

    正如@mrry建议的那样,如果您只需要处理单个图像,则可以删除队列运行程序:

    # 1. GRAPH CREATION
    input_img = tf.image.decode_jpeg(tf.read_file("/home/.../your_image.jpg"), channels=3)
    reshaped_image = tf.image.resize_image_with_crop_or_pad(tf.cast(input_img, width, height), tf.float32)
    float_image = tf.image.per_image_withening(reshaped_image)
    images = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size = 1)
    logits = faultnet.inference(images)
    _, top_k_pred = tf.nn.top_k(logits, k=5)
    
    # 2. TENSORFLOW SESSION
    with tf.Session() as sess:
      sess.run(init_op)
    
      top_indices = sess.run([top_k_pred])
      print ("Predicted ", top_indices[0], " for your input image.")
    

答案 1 :(得分:0)

cifar10_eval.py中的原始源代码也可用于测试自己的单个图像,如以下控制台输出中所示

nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ python cifar10_eval.py --run_once True 2>/dev/null
[ -0.63916457  -3.31066918   2.32452989   1.51062226  15.55279636
-0.91585422   1.26451302  -4.11891603  -7.62230825  -4.29096413]
deer
nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ python cifar2bin.py matchbox.png input.bin 
nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ python cifar10_eval.py --run_once True 2>/dev/null
[ -1.30562115  12.61497402  -1.34208572  -1.3238833   -6.13368177
-1.17441642  -1.38651907  -4.3274951    2.05489922   2.54187846]
automobile
nbatfai@robopsy:~/Robopsychology/repos/gpu/tensorflow/tensorflow/models/image/cifar10$ 

和代码段

#while step < num_iter and not coord.should_stop():
# predictions = sess.run([top_k_op])
print(sess.run(logits[0]))
classification = sess.run(tf.argmalogits[0], 0))
cifar10classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(cifar10classes[classification])

#true_count += np.sum(predictions)
step += 1

# Compute precision @ 1.
precision = true_count / total_sample_count
# print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

更多详情可在帖子How can I test own image to Cifar-10 tutorial on Tensorflow?

中找到