Tensorflow:初始v3批处理

时间:2016-11-21 12:13:35

标签: python tensorflow

我正在使用Inception v3提取瓶颈张量。我的问题:我一次只能提供一张图片:

sess.graph.get_tensor_by_name("pool_3:0").eval(session=sess, feed_dict={'DecodeJpeg:0':single_image})

批量处理多个图像会加速相当多的事情,我想。建议使用here的解决方案,但我无法使其工作(使用tensorflow v0.10.0和0.11.0rc0进行测试,从http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz下载的初始模型)。

import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2
import numpy as np

def create_graph():
    with gfile.FastGFile('classify_image_graph_def.pb', 'rb') as f:
        graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

create_graph()

img = cv2.imread("some_img.jpg")
img = cv2.resize(img, (299,299), interpolation = cv2.INTER_CUBIC) 
# replicate image 10 times
img = np.array(10*[img]).astype('float')

with tf.Session() as sess:
    pooled_2 =  sess.graph.get_tensor_by_name("pool_3:0").eval(session=sess, feed_dict={'ResizeBilinear:0':img})

这给了我(似是而非的)错误消息:

Traceback (most recent call last):
  File "extract_bottlenecks_minimal.py", line 26, in <module>
    pooled_2 = sess.graph.get_tensor_by_name("pool_3:0").eval(session=sess, feed_dict={'ResizeBilinear:0':img})
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 559, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3761, in _eval_using_default_session
    return session.run(tensors, feed_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 894, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (10, 299, 299, 3) for Tensor u'ResizeBilinear:0', which has shape '(1, 299, 299, 3)'

我还发现this issue声称开始使用图像批处理,但每次输入输入时,此代码都需要设置 整个网络 (我想只设置一次网络。

感谢您的兴趣 - 非常感谢任何帮助; - )

1 个答案:

答案 0 :(得分:2)

(1)请尝试使用更新的Inception v3模型,您可以在此处找到:

http://download.tensorflow.org/models/image/imagenet/inception-v3-2016-03-01.tar.gz

这是我在2月份开放的bug后发布的版本,它应该支持评估路径上的批量维度。 imagenet_eval脚本应该可以使用它。

(2)您可能需要解码JPEG并在进入网络之前调整大小为299x299。我没有看过预训练的是否接受可变批量大小的解码,但它至少应该接受它到网络本身。

有关详细信息,请参阅image recognition tutorial.