我想知道如何修改classify_image.py(来自this tutorial,以便我可以从另一个python脚本中导入它。我基本上希望它具有与之相同的功能,但不是提供图像路径并在终端中打印出响应,我想给出一个函数图像路径,并获得函数返回前5个结果及其概率。
我还没有找到这个问题的直接解决方案,但我意识到我的问题解决和搜索以前的答案是有限的,因为我遗憾的是还没有学习Tensorflow的基础知识。
当然,如果有另一个预先训练好的Tensorflow模型同样出色并满足我的要求,我很乐意使用它。
此致 本都
更新也许我应该稍微澄清一下:
我不想训练模型,只需使用经过预先训练的模型进行图像识别,在这种情况下有一个图像识别脚本,我可以将其作为模块导入到另一个python应用程序中。
我也尝试使用this tutorial的代码,但我也被卡在那里,在这种情况下,它包含了很多手动安装,我可能会在某些步骤中失败。 classify_image.py example的好处在于我让它在教程中按预期工作,所以我认为从那里开始使用它作为可插拔模块的步骤不应该那么大。
我尝试过(使用classify_image.py)将if __name__ = '__main__'
下面的行移动到main(_)
,以便在我从另一个脚本调用它们时执行它们但我仍然遇到问题。我主要遇到main(_)
函数的问题,它要我传递一个参数,并且从周围搜索我认为_
似乎是从cli获取输入时使用的某种占位符。所有的FLAGS东西似乎也与cli有关,这就是我想要摆脱的。我也不确定模型权重等是否正确保存,以便能够从另一个脚本中使用它。同样,在这一点上,我只想玩图像分类器,并希望进一步了解它背后的机器学习。对不起我对此基础知识缺乏了解!
classify_image.py:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple image classification with Inception.
Run image classification with Inception trained on ImageNet 2012 Challenge data
set.
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
to use this script to perform image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import re
import sys
import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
FLAGS = None
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
def __init__(self,
label_lookup_path=None,
uid_lookup_path=None):
if not label_lookup_path:
label_lookup_path = os.path.join(
FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
if not uid_lookup_path:
uid_lookup_path = os.path.join(
FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
"""Loads a human readable English name for each softmax node.
Args:
label_lookup_path: string UID to integer node ID.
uid_lookup_path: string UID to human-readable string.
Returns:
dict from integer node ID to human-readable string.
"""
if not tf.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path)
if not tf.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines:
parsed_items = p.findall(line)
uid = parsed_items[0]
human_string = parsed_items[2]
uid_to_human[uid] = human_string
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
target_class_string = line.split(': ')[1]
node_id_to_uid[target_class] = target_class_string[1:-2]
# Loads the final mapping of integer node ID to human-readable string
node_id_to_name = {}
for key, val in node_id_to_uid.items():
if val not in uid_to_human:
tf.logging.fatal('Failed to locate: %s', val)
name = uid_to_human[val]
node_id_to_name[key] = name
return node_id_to_name
def id_to_string(self, node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
def run_inference_on_image(image):
"""Runs inference on an image.
Args:
image: Image file name.
Returns:
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
with tf.Session() as sess:
# Some useful tensors:
# 'softmax:0': A tensor containing the normalized prediction across
# 1000 labels.
# 'pool_3:0': A tensor containing the next-to-last layer containing 2048
# float description of the image.
# 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
# encoding of the image.
# Runs the softmax tensor by feeding the image_data as input to the graph.
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions)
# Creates node ID --> English string lookup.
node_lookup = NodeLookup()
top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
def maybe_download_and_extract():
"""Download and extract model tar file."""
dest_directory = FLAGS.model_dir
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def main(_):
maybe_download_and_extract()
image = (FLAGS.image_file if FLAGS.image_file else
os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
run_inference_on_image(image)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# classify_image_graph_def.pb:
# Binary representation of the GraphDef protocol buffer.
# imagenet_synset_to_human_label_map.txt:
# Map from synset ID to a human readable string.
# imagenet_2012_challenge_label_map_proto.pbtxt:
# Text representation of a protocol buffer mapping a label to synset ID.
parser.add_argument(
'--model_dir',
type=str,
default='/tmp/imagenet',
help="""\
Path to classify_image_graph_def.pb,
imagenet_synset_to_human_label_map.txt, and
imagenet_2012_challenge_label_map_proto.pbtxt.\
"""
)
parser.add_argument(
'--image_file',
type=str,
default='',
help='Absolute path to image file.'
)
parser.add_argument(
'--num_top_predictions',
type=int,
default=5,
help='Display this many predictions.'
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
答案 0 :(得分:0)
1)第一个问题是关于如何返回预测值。 以下代码段对给定图像进行了预置:
dict_orig = {u'start_time': u'1437056839370', u'playback': {u'duration': u'873041'}, u'end_time': u'1437058474763', u'id': u'61581a89c0804655f3a49b0df54468405d2bd78a'}
dict_reduced = {k: dict_orig[k] for k in dict_orig.keys() & set(u'start_time', u'duration', u'end_time', u'id')}
您可以将结果保存在某些数据结构中并返回,而不是打印。默认情况下,如果您要将此行为更改为 top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
,则会返回5个最高预测值。
2)关于模型: 它有两个部分 -
但是如果您仍然希望使用自己的数据集训练您的系统,我会说最初使用imagenet进行训练,然后使用您自己的数据集训练最后一层(张量名称为' final_result ')。请找到tutorial。
答案 1 :(得分:0)
最后,我设法使用原始问题更新中提到的SO文章中的代码。我从所述SO问题的答案修改了附加im = 2*(im/255.0)-1.0
的代码,在我的计算机上修改了PIL的一些行以及将类转换为人类可读标签的函数(在github上找到),链接到下面的那个文件。我使它成为一个可调用的函数,它将一个图像列表作为输入,并输出一个标签列表和预测值。如果你想使用它,这就是你必须要做的:
git clone https://github.com/tensorflow/models/
您想要的模型。使用SO问题中的代码并进行一些修改,将其放在项目的.py文件中:
import tensorflow as tf
slim = tf.contrib.slim
import PIL as pillow
from PIL import Image
#import Image
from inception_resnet_v2 import *
import numpy as np
with open('imagenet1000_clsid_to_human.txt','r') as inf:
imagenet_classes = eval(inf.read())
def get_human_readable(id):
id = id - 1
label = imagenet_classes[id]
return label
checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'
#Load the model
sess = tf.Session()
arg_scope = inception_resnet_v2_arg_scope()
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
def classify_image(sample_images):
classifications = []
for image in sample_images:
im = Image.open(image).resize((299,299))
im = np.array(im)
im = im.reshape(-1,299,299,3)
im = 2*(im/255.0)-1.0
predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
#print (np.max(predict_values), np.max(logit_values))
#print (np.argmax(predict_values), np.argmax(logit_values))
label = get_human_readable(np.argmax(predict_values))
predict_value = np.max(predict_values)
classifications.append({"label":label, "predict_value":predict_value})
return classifications
答案 2 :(得分:0)
在我的情况下,只需将[-FLAGS.num_top_predictions:]
替换为[-5:]
然后用目录替换其他FLAG并将图像放在文件上。