如何在Tensorflow教程中将Imagenet(classify_image.py)预先训练的inception-v3模型作为模块导入?

时间:2017-02-20 12:48:25

标签: python tensorflow

我想知道如何修改classify_image.py(来自this tutorial,以便我可以从另一个python脚本中导入它。我基本上希望它具有与之相同的功能,但不是提供图像路径并在终端中打印出响应,我想给出一个函数图像路径,并获得函数返回前5个结果及其概率。

我还没有找到这个问题的直接解决方案,但我意识到我的问题解决和搜索以前的答案是有限的,因为我遗憾的是还没有学习Tensorflow的基础知识。

当然,如果有另一个预先训练好的Tensorflow模型同样出色并满足我的要求,我很乐意使用它。

此致 本都

更新也许我应该稍微澄清一下:

我不想训练模型,只需使用经过预先训练的模型进行图像识别,在这种情况下有一个图像识别脚本,我可以将其作为模块导入到另一个python应用程序中。

我也尝试使用this tutorial的代码,但我也被卡在那里,在这种情况下,它包含了很多手动安装,我可能会在某些步骤中失败。 classify_image.py example的好处在于我让它在教程中按预期工作,所以我认为从那里开始使用它作为可插拔模块的步骤不应该那么大。

我尝试过(使用classify_image.py)将if __name__ = '__main__'下面的行移动到main(_),以便在我从另一个脚本调用它们时执行它们但我仍然遇到问题。我主要遇到main(_)函数的问题,它要我传递一个参数,并且从周围搜索我认为_似乎是从cli获取输入时使用的某种占位符。所有的FLAGS东西似乎也与cli有关,这就是我想要摆脱的。我也不确定模型权重等是否正确保存,以便能够从另一个脚本中使用它。同样,在这一点上,我只想玩图像分类器,并希望进一步了解它背后的机器学习。对不起我对此基础知识缺乏了解!

classify_image.py:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Simple image classification with Inception.
Run image classification with Inception trained on ImageNet 2012 Challenge data
set.
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
to use this script to perform image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

FLAGS = None

# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long


class NodeLookup(object):
  """Converts integer node ID's to human readable labels."""

  def __init__(self,
               label_lookup_path=None,
               uid_lookup_path=None):
    if not label_lookup_path:
      label_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
    if not uid_lookup_path:
      uid_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
    self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

  def load(self, label_lookup_path, uid_lookup_path):
    """Loads a human readable English name for each softmax node.
    Args:
      label_lookup_path: string UID to integer node ID.
      uid_lookup_path: string UID to human-readable string.
    Returns:
      dict from integer node ID to human-readable string.
    """
    if not tf.gfile.Exists(uid_lookup_path):
      tf.logging.fatal('File does not exist %s', uid_lookup_path)
    if not tf.gfile.Exists(label_lookup_path):
      tf.logging.fatal('File does not exist %s', label_lookup_path)

    # Loads mapping from string UID to human-readable string
    proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
      parsed_items = p.findall(line)
      uid = parsed_items[0]
      human_string = parsed_items[2]
      uid_to_human[uid] = human_string

    # Loads mapping from string UID to integer node ID.
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
    for line in proto_as_ascii:
      if line.startswith('  target_class:'):
        target_class = int(line.split(': ')[1])
      if line.startswith('  target_class_string:'):
        target_class_string = line.split(': ')[1]
        node_id_to_uid[target_class] = target_class_string[1:-2]

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
      if val not in uid_to_human:
        tf.logging.fatal('Failed to locate: %s', val)
      name = uid_to_human[val]
      node_id_to_name[key] = name

    return node_id_to_name

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


def create_graph():
  """Creates a graph from saved GraphDef file and returns a saver."""
  # Creates graph from saved graph_def.pb.
  with tf.gfile.FastGFile(os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image(image):
  """Runs inference on an image.
  Args:
    image: Image file name.
  Returns:
    Nothing
  """
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  image_data = tf.gfile.FastGFile(image, 'rb').read()

  # Creates graph from saved GraphDef.
  create_graph()

  with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup()

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))


def maybe_download_and_extract():
  """Download and extract model tar file."""
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def main(_):
  maybe_download_and_extract()
  image = (FLAGS.image_file if FLAGS.image_file else
           os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
  run_inference_on_image(image)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # classify_image_graph_def.pb:
  #   Binary representation of the GraphDef protocol buffer.
  # imagenet_synset_to_human_label_map.txt:
  #   Map from synset ID to a human readable string.
  # imagenet_2012_challenge_label_map_proto.pbtxt:
  #   Text representation of a protocol buffer mapping a label to synset ID.
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/imagenet',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )
  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

3 个答案:

答案 0 :(得分:0)

1)第一个问题是关于如何返回预测值。 以下代码段对给定图像进行了预置:

dict_orig = {u'start_time': u'1437056839370', u'playback': {u'duration': u'873041'}, u'end_time': u'1437058474763', u'id': u'61581a89c0804655f3a49b0df54468405d2bd78a'}
dict_reduced = {k: dict_orig[k] for k in dict_orig.keys() & set(u'start_time', u'duration', u'end_time', u'id')}

您可以将结果保存在某些数据结构中并返回,而不是打印。默认情况下,如果您要将此行为更改为 top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1] for node_id in top_k: human_string = node_lookup.id_to_string(node_id) score = predictions[node_id] print('%s (score = %.5f)' % (human_string, score)) ,则会返回5个最高预测值。

2)关于模型: 它有两个部分 -

  1. 你需要像Imagenet一样拥有高质量的数据集。
  2. 假设您拥有这样的高质量数据集,那么培训开始的基础设施将需要非常强大的GPU。很多时间。
  3. 但是如果您仍然希望使用自己的数据集训练您的系统,我会说最初使用imagenet进行训练,然后使用您自己的数据集训练最后一层(张量名称为' final_result ')。请找到tutorial

答案 1 :(得分:0)

最后,我设法使用原始问题更新中提到的SO文章中的代码。我从所述SO问题的答案修改了附加im = 2*(im/255.0)-1.0的代码,在我的计算机上修改了PIL的一些行以及将类转换为人类可读标签的函数(在github上找到),链接到下面的那个文件。我使它成为一个可调用的函数,它将一个图像列表作为输入,并输出一个标签列表和预测值。如果你想使用它,这就是你必须要做的:

  1. 安装最新的Tensorflow版本(此时需要1.0,这是必需的)。
  2. git clone https://github.com/tensorflow/models/您想要的模型。
  3. 从我之前提到的SO问题中提出this checkpoint file(当然需要提取)在项目目录中。
  4. this text file(人类可读标签)放在项目目录中。
  5. 使用SO问题中的代码并进行一些修改,将其放在项目的.py文件中:

    import tensorflow as tf
    slim = tf.contrib.slim
    import PIL as pillow
    from PIL import Image
    #import Image
    from inception_resnet_v2 import *
    import numpy as np
    
    with open('imagenet1000_clsid_to_human.txt','r') as inf:
        imagenet_classes = eval(inf.read())
    
    def get_human_readable(id):
        id = id - 1
        label = imagenet_classes[id]
    
        return label
    
    checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'
    
    #Load the model
    sess = tf.Session()
    arg_scope = inception_resnet_v2_arg_scope()
    input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])  
    with slim.arg_scope(arg_scope):
        logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)
    
    def classify_image(sample_images):
        classifications = []
        for image in sample_images:
            im = Image.open(image).resize((299,299))
            im = np.array(im)
            im = im.reshape(-1,299,299,3)
            im = 2*(im/255.0)-1.0
            predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
            #print (np.max(predict_values), np.max(logit_values))
            #print (np.argmax(predict_values), np.argmax(logit_values))
            label = get_human_readable(np.argmax(predict_values))
            predict_value = np.max(predict_values)
            classifications.append({"label":label, "predict_value":predict_value})
    
        return classifications
    

答案 2 :(得分:0)

在我的情况下,只需将[-FLAGS.num_top_predictions:]替换为[-5:]

然后用目录替换其他FLAG并将图像放在文件上。