使用预训练(Tensorflow)CNN对图像进行分类

时间:2017-02-27 16:02:58

标签: tensorflow neural-network classification conv-neural-network tf-slim

我已经在自己的数据集上训练了alexnet_v2,现在想在另一个应用程序中使用它。这应该非常简单,并且我已尝试以多种方式实现它,但要么我得到错误我无法解决,或者(在下面的代码的情况下)它会无限期地挂起。

理想情况下,我喜欢C ++(但C ++ API似乎不可靠,或者至少在很多地方都有过时的文档,所以python是可以接受的),我想分类大组图像(例如:为程序提供80个动物图像并返回其中是否有任何动物显示猫)。

我是否以正确的方式使用下面的代码?如果是这样,我该如何解决呢。

如果没有,是否有更好的方法的工作示例?

非常感谢。

import tensorflow as tf

#Using preprocessing and alexnet_v2 net from the slim examples

from nets import nets_factory
from preprocessing import preprocessing_factory

#Checkpoint file from training on binary dataset

checkpoint_path = '/home/ubuntu/tensorflow/models/slim/data/checkpoint.ckpt'

slim = tf.contrib.slim

number_of_classes = 2


image_filename = '/home/ubuntu/tensorflow/models/slim/data/images/neg_sample_123459.jpg'

image_filename_placeholder = tf.placeholder(tf.string)

image_tensor = tf.read_file(image_filename_placeholder)

image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)

image_batch_tensor = tf.expand_dims(image_tensor, axis=0)

#Use slim's alexnet_v2 implementation

network_fn = nets_factory.get_network_fn('alexnet_v2',num_classes=2,is_training=False)

#Use inception preprocessing

preprocessing_name = 'inception'
image_preprocessing_fn= preprocessing_factory.get_preprocessing(preprocessing_name,is_training=False)

image_tensor=image_preprocessing_fn(image_tensor,network_fn.default_image_size,network_fn.default_image_size)

label=3
images,labels=tf.train.batch(
    [image_tensor,label],
    batch_size=2,
    num_threads=1,
    capacity=10)

pred,_=network_fn(images)

initializer = tf.local_variables_initializer()

init_fn=slim.assign_from_checkpoint_fn(
    checkpoint_path,
    slim.get_model_variables('alexnet_v2'))

with tf.Session() as sess:

    sess.run(initializer)
    init_fn(sess)
    tf.train.start_queue_runners(sess)
    image_np, pred_np = sess.run([image_tensor, pred], feed_dict={image_filename_placeholder: image_filename})

编辑:以粗体添加行后,程序不再挂起。但是我收到占位符错误:

  

InvalidArgumentError:您必须为占位符张量提供值   '占位符'用dtype string [[Node:Placeholder =   Placeholderdtype = DT_STRING,shape = [],   _device =" /作业:本地主机/复制:0 /任务:0 / CPU:0"]]

我已经仔细检查了拼写,据我所知,我正确地喂它。怎么了?

1 个答案:

答案 0 :(得分:0)

tf.train.batch()函数使用后台线程来预取示例,但是您需要添加一个显式命令(tf.train.start_queue_runners(sess))来启动这些线程。按如下方式重写代码的最后一部分应该会阻止它挂起:

with tf.Session() as sess:
  sess.run(initializer)
  init_fn(sess)

  # Starts background threads for input preprocessing.
  tf.train.start_queue_runners(sess)

  image_np, pred_np = sess.run(
      [image_tensor, pred],
      feed_dict={image_filename_placeholder: image_filename})