This question is very similar to: How do you use freeze_graph.py in Tensorflow? but it has not been answered there, and I am approaching the problem from a different angle, so I want to raise it here.
I am also trying to load a .pb binary file and then freeze it. This is the code I tried.
Let me know if it gives you any ideas. It does not return an error; it simply crashes my Jupyter notebook.
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
with tf.Session() as sess:
    model_filename = 'saved_model.pb'  # binary .pb file
    with gfile.FastGFile(model_filename, 'rb') as f:
        data = compat.as_bytes(f.read())  # reads the raw bytes
        sm = saved_model_pb2.SavedModel()
        print(sm)
        sm.ParseFromString(data)  # parses the SavedModel proto
        print(sm)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
    output_graph = "frozen_graph.pb"
    # Getting all output nodes for the frozen graph
    output_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
    # This is not working fully
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,                                    # The session is used to retrieve the weights
        tf.get_default_graph().as_graph_def(),   # The graph_def is used to retrieve the nodes
        output_nodes                             # The output node names are used to select the useful nodes
    )
    # Finally we serialize and dump the output graph to the filesystem
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
    print(g_in)
    LOGDIR = '.'
    train_writer = tf.summary.FileWriter(LOGDIR)
    train_writer.add_graph(sess.graph)
This code should produce a frozen file, but I don't completely understand TensorFlow's saving mechanism. If I take the freezing portion out of this code, I do get an events.out file that TensorBoard can read.
Answer 0 (score: 4)
So, after a lot of stumbling, I realized I had only been loading the meta graph, not the whole graph with its variables. Here is code that does it:
import tensorflow as tf

def frozen_graph_maker(export_dir, output_graph):
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        output_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,            # The session is used to retrieve the weights
            sess.graph_def,  # The graph_def holds the network of nodes
            output_nodes     # The output node names are used to select the useful nodes
        )
        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

def main():
    export_dir = '/dir/of/pb/and/variables'
    output_graph = "frozen_graph.pb"
    frozen_graph_maker(export_dir, output_graph)

if __name__ == '__main__':
    main()
I realize now that I was only loading the meta graph. I would be glad if someone could confirm my understanding of the failure: with compat.as_bytes I was only loading the SavedModel proto as a meta graph. Is there a way to incorporate the variables after loading it that way, or should I stick with tf.saved_model.loader.load()? My original loading attempt was completely wrong, since it never even touched the variables folder.
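For what it's worth, here is a rough, unverified sketch of how the variables might be restored by hand after parsing the proto, since they live in the SavedModel's variables/ folder as an ordinary checkpoint and the meta graph carries a saver_def. The export_dir path is only illustrative, and it assumes the exported graph_def contains the saver ops; tf.saved_model.loader.load() does this bookkeeping for you, which is why sticking with it is simpler.

import os
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

export_dir = '/dir/of/pb/and/variables'  # illustrative SavedModel directory

# Parse the SavedModel proto by hand, as in the original attempt.
with tf.gfile.GFile(os.path.join(export_dir, 'saved_model.pb'), 'rb') as f:
    sm = saved_model_pb2.SavedModel()
    sm.ParseFromString(compat.as_bytes(f.read()))
meta_graph = sm.meta_graphs[0]

with tf.Session(graph=tf.Graph()) as sess:
    # Import with name='' so the node names match those in the saver_def.
    tf.import_graph_def(meta_graph.graph_def, name='')
    # Rebuild a Saver from the SavedModel's saver_def and restore from the
    # checkpoint that ships inside the SavedModel's variables/ folder.
    saver = tf.train.Saver(saver_def=meta_graph.saver_def)
    saver.restore(sess, os.path.join(export_dir, 'variables', 'variables'))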
Another question: with [n.name for n in tf.get_default_graph().as_graph_def().node] I am putting every node into output_nodes. Should I only pass the final output node instead? It also works with just the last node. What is the difference?
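As far as I understand, convert_variables_to_constants only needs the names of the true output nodes; it keeps everything those outputs depend on, so passing every node name simply keeps the entire graph, including ops never run at inference. Here is a minimal sketch of reading the real output node names from the serving signature instead, assuming the model was exported with the default serving signature key (that key and the export_dir path are assumptions):

import tensorflow as tf

export_dir = '/dir/of/pb/and/variables'  # illustrative SavedModel directory

with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    # Assumes the model was exported with the default serving signature.
    key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    signature = meta_graph_def.signature_def[key]
    # TensorInfo names look like 'fc2/Relu:0'; strip the ':0' to get node names.
    output_nodes = [t.name.split(':')[0] for t in signature.outputs.values()]
    print(output_nodes)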
Answer 1 (score: 0)
A simpler solution is the following:
import tensorflow as tf

pb_saved_model = "/Users/vedanshu/saved_model/"

_graph = tf.Graph()
with _graph.as_default():
    _sess = tf.Session(graph=_graph)
    model = tf.saved_model.loader.load(_sess, ["serve"], pb_saved_model)
    with tf.gfile.GFile("/Users/vedanshu/frozen_graph/frozen.pb", "wb") as f:
        f.write(model.SerializeToString())
If your saved_model contains variables, they can be converted to constants as follows:
import tensorflow as tf

pb_saved_model = "/Users/vedanshu/saved_model/"
OUTPUT_NAMES = ["fc2/Relu"]

_graph = tf.Graph()
with _graph.as_default():
    _sess = tf.Session(graph=_graph)
    model = tf.saved_model.loader.load(_sess, ["serve"], pb_saved_model)
    graphdef = tf.get_default_graph().as_graph_def()
    frozen_graph = tf.graph_util.convert_variables_to_constants(_sess, graphdef, OUTPUT_NAMES)
    frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
    with tf.gfile.GFile("/Users/vedanshu/frozen_graph/frozen.pb", "wb") as f:
        f.write(frozen_graph.SerializeToString())  # serialize the GraphDef before writing
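Once frozen.pb has been written, it can be loaded back as a plain GraphDef for inference. Here is a minimal sketch, assuming the same fc2/Relu:0 output and a hypothetical input placeholder named input:0 with a made-up shape (both depend on your model):

import numpy as np
import tensorflow as tf

with tf.gfile.GFile("/Users/vedanshu/frozen_graph/frozen.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
    # 'input:0' is a hypothetical placeholder name; check your graph's actual inputs.
    x = graph.get_tensor_by_name("input:0")
    y = graph.get_tensor_by_name("fc2/Relu:0")
    with tf.Session(graph=graph) as sess:
        dummy_input = np.zeros((1, 784), dtype=np.float32)  # hypothetical input shape
        print(sess.run(y, feed_dict={x: dummy_input}))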