我正在尝试加载从https://github.com/tensorflow/models/tree/master/official/resnet获取的经过训练的模型,但是当我尝试加载.pb
时,在ParseFromString
方法上遇到错误:
import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = '../resnet_v2_fp32_savedmodel_NHWC/1538687283/saved_model.pb'
with tf.gfile.FastGFile(GRAPH_PB_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
DecodeError: Error parsing message
我怎么了?
答案 0 :(得分:1)
我遇到了类似的问题,而不是使用gfile,而是使用了tf.saved_model.loader.load函数,如本文https://stackoverflow.com/a/46547595/4637693:
sess = tf.Session(graph=tf.Graph())
model = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './model')
graph_def = model.graph_def