我从here下载了张量流模型。现在我试图在python中加载它。根据{{3}},.pb文件可以像这样加载:
import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
fname = "20170512-110547.pb"
with tf.Session() as persisted_sess:
print("load graph")
with gfile.FastGFile(fname,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print("map variables")
persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
try:
saver = tf.train.Saver(tf.all_variables()) # 'Saver' misnomer! Better: Persister!
except:pass
print("load data")
saver.restore(persisted_sess, "checkpoint.data") # now OK
print(persisted_result.eval())
print("DONE")
但是我收到了错误
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-2-9995bff0b3cd> in <module>()
12 graph_def.ParseFromString(f.read())
13 persisted_sess.graph.as_default()
---> 14 tf.import_graph_def(graph_def, name='')
15 print("map variables")
16 persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
/net/hciserver03/storage/oblum/venvs/compvisgpu01/lib/python2.7/site-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
256 for node in graph_def.node:
257 # Set any default attr values that aren't present.
--> 258 op_def = op_dict[node.op]
259 for attr_def in op_def.attr:
260 key = attr_def.name
KeyError: u'FIFOQueueV2'