如何从.pb文件导入模型

时间:2020-04-06 04:04:50

标签: tensorflow2.0

'''

# Load the graph stored in a SavedModel's saved_model.pb and list its operations.
#
# saved_model.pb is a serialized SavedModel protocol buffer, not a bare
# GraphDef. Parsing it directly with GraphDef.ParseFromString is what raises
# "DecodeError ... wire type". The file must be parsed as a SavedModel, and
# the GraphDef embedded in one of its MetaGraphDefs is imported instead.
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2

model_filename = "./model/skipGram-word2Vec/saved_model.pb"
with tf.io.gfile.GFile(model_filename, 'rb') as f:
    saved_model = saved_model_pb2.SavedModel()
    saved_model.ParseFromString(f.read())

if not saved_model.meta_graphs:
    raise ValueError(f"{model_filename} contains no MetaGraphDef")

# A SavedModel may hold several MetaGraphDefs (one per tag set); this uses the
# first one. NOTE(review): select by tags if the export contains more than one.
meta_graph = saved_model.meta_graphs[0]
print("tags:", list(meta_graph.meta_info_def.tags))

graph = tf.Graph()
with graph.as_default():
    # import_graph_def returns None unless return_elements is given, so there
    # is nothing to sess.run() here; the imported operations are listed instead.
    tf.compat.v1.import_graph_def(meta_graph.graph_def, name='')

with tf.compat.v1.Session(graph=graph) as sess:
    # import_graph_def does not restore variable values. To run the model with
    # its trained weights, use tf.compat.v1.saved_model.loader.load(sess, tags,
    # export_dir) or, in TF 2.x, tf.saved_model.load(export_dir).
    print([op.name for op in sess.graph.get_operations()])

''' 然后出现错误：DecodeError：标签（tag）中的 wire type 错误。

1 个答案:

答案 0 :(得分:0)

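下面是在 TensorFlow 2.0 中从冻结图（frozen GraphDef）.pb 文件加载模型的方法。

注意：SavedModel 目录中的 saved_model.pb 存储的是 SavedModel 协议缓冲区，而不是单独的 GraphDef。直接用 GraphDef.ParseFromString 解析它就会引发上面的 DecodeError。对于这类文件，应使用 tf.saved_model.load(导出目录) 加载整个导出目录。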

import tensorflow as tf

# Load a frozen GraphDef (.pb) into a TF1-compatible session and print the
# names of all nodes it contains.
#
# This only works for a frozen GraphDef file. The saved_model.pb inside a
# SavedModel directory is a SavedModel proto, and parsing it as a GraphDef
# raises DecodeError; load such directories with tf.saved_model.load instead.
GRAPH_PB_PATH = './frozen_model.pb'

with tf.compat.v1.Session() as sess:
    print("load graph")
    with tf.io.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        # Parse inside the with-block, while f is still open.
        graph_def.ParseFromString(f.read())
    # Entering the Session context already makes sess.graph the default graph,
    # so the nodes are imported into sess.graph without an extra as_default().
    tf.compat.v1.import_graph_def(graph_def, name='')
    names = [node.name for node in graph_def.node]
    print(names)