Tensorflow:从图形文件中获取预测(.pb文件)

时间:2017-10-07 18:59:56

标签: python tensorflow tensorboard

我正在使用图形文件(pb文件),此Tensorflow模型的目的是提供对某些图像的预测

我已经开发了一个加载图形文件的代码,但我不能使用stat会话。 可用的文件是: -

  • training_model_saved_model.pb
  • 变量
    • training_model_variables_variables.data-00000-的-00001
    • training_model_variables_variables.index

输出错误包含大型模型层列表。在这种情况下我能做什么,感谢任何帮助

这是我用来加载/运行模型的代码

import tensorflow as tf
import sys
import os



import matplotlib.image as mpimg
import matplotlib.pyplot as plt


from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile

export_dir = os.path.join("./", "variables/")
filename = "imgpsh_fullsize.jpeg"
raw_image_data = mpimg.imread(filename)

g = tf.Graph()
with tf.Session(graph=g) as sess:
   model_filename ='training_model_saved_model.pb'
   with gfile.FastGFile(model_filename, 'rb') as f:

        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        #print(sm)
        if 1 != len(sm.meta_graphs):
                print('More than one graph found. Not sure which to write')
                sys.exit(1)

        image_input= tf.import_graph_def(sm.meta_graphs[0].graph_def,name='',return_elements=["input"])
        #print(image_input)
        #saver =  tf.train.Saver()
        saver = tf.train.import_meta_graph(sm.meta_graphs[0].graph_def)
        '''
        print(image_input)

        x = g.get_tensor_by_name("input:0")

        print(x)
        '''
        saver.restore(sess,model_filename)

        predictions = sess.run(feed_dict={image: raw_image_data})
        print('###################################################')
        print(predictions)

错误存在

Traceback (most recent call last):
  File "model_Input-get.py", line 35, in <module>
    saver = tf.train.import_meta_graph(sm.meta_graphs[0].graph_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1691, in import_meta_graph
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/meta_graph.py", line 553, in read_meta_graph_file
    if not file_io.file_exists(filename):
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/lib/io/file_io.py", line 252, in file_exists
    pywrap_tensorflow.FileExists(compat.as_bytes(filename), status)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: -1
          }
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }

1 个答案:

答案 0 :(得分:0)

您似乎将TensorFlow服务SavedModel格式与常规TensorFlow导出/恢复功能混合使用。

这是TensorFlow代码库中一个特别令人困惑的部分,因为这种格式在首次出现时没有详细记录 - 并且没有很多示例显示何时使用此格式与原始格式。

我的建议是:

  1. 切换到TF服务并继续使用SavedModel格式或
  2. 坚持原始导出/恢复模式格式。