Tensorflow:加载.pb文件,然后将其保存为冻结图问题

时间:2018-08-13 16:23:45

标签: python tensorflow

此问题与以下问题非常相似:How do you use freeze_graph.py in Tensorflow? 但是还没有人回答,我对这个问题有不同的看法。因此,我想提出一些意见。

我还试图加载.pb二进制文件,然后将其冻结。这是我尝试的代码。

让我知道这是否给您任何想法。这不会返回错误。它只是使我的jupyter笔记本崩溃了。

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.Session() as sess:
    model_filename ='saved_model.pb' # binary .pb file
    with gfile.FastGFile(model_filename, 'rb') as f:

        data = compat.as_bytes(f.read()) # reads binary
        sm = saved_model_pb2.SavedModel() 
        print(sm)
        sm.ParseFromString(data) # parses through the file
        print(sm)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)

        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
        output_graph = "frozen_graph.pb"

        # Getting all output nodes for the frozen graph
        output_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
        # This not working fully   
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            output_nodes# The output node names are used to select the usefull nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
        print(g_in)
LOGDIR='.'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)

这段代码应该生成一个冻结的文件,但是我不完全了解tensorflow的保存机制。如果我从这段代码中冻结了图形部分,则会得到和events.out。张量板可以读取的文件。

2 个答案:

答案 0 :(得分:4)

因此,经过很多绊脚石,我意识到我只是在加载元图。不是整个图都带有变量。这是执行此操作的代码:

def frozen_graph_maker(export_dir,output_graph):
    with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    output_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
    output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            sess.graph_def,
            output_nodes# The output node names are used to select the usefull nodes
    )       
    # Finally we serialize and dump the output graph to the filesystem
    with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
def main():
    export_dir='/dir/of/pb/and/variables'
    output_graph = "frozen_graph.pb"
    frozen_graph_maker(export_dir,output_graph)

我意识到我只是在加载元图。如果有人可以证实我对失败的理解,我将非常乐意。使用compat.as_bytes,我只是将其作为元图加载。完成这种加载后,是否可以整合变量?还是应该坚持使用tf.saved_model.loader.load()? 我的加载尝试是完全错误的,因为它甚至没有调用变量文件夹。

另一个问题:使用[n.name for n in tf.get_default_graph().as_graph_def().node],我将所有节点放入output_nodes中,我是否应该仅将最后一个节点放入?它仅适用于最后一个节点。有什么区别?

答案 1 :(得分:0)

一个更简单的解决方案如下:

import tensorflow as tf

pb_saved_model = "/Users/vedanshu/saved_model/"

_graph = tf.Graph()
with _graph.as_default():
    _sess = tf.Session(graph=_graph)
    model = tf.saved_model.loader.load(_sess, ["serve"], pb_saved_model)

with tf.gfile.GFile("/Users/vedanshu/frozen_graph/frozen.pb", "wb") as f:
    f.write(model.SerializeToString())

如果您的save_model中包含变量,则可以将其转换为常量,如下所示:

import tensorflow as tf

pb_saved_model = "/Users/vedanshu/saved_model/"
OUTPUT_NAMES = ["fc2/Relu"]

_graph = tf.Graph()
with _graph.as_default():
    _sess = tf.Session(graph=_graph)
    model = tf.saved_model.loader.load(_sess, ["serve"], pb_saved_model)
    graphdef = tf.get_default_graph().as_graph_def()
    frozen_graph = tf.graph_util.convert_variables_to_constants(_sess,graphdef, OUTPUT_NAMES)
    frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)

with tf.gfile.GFile("/Users/vedanshu/frozen_graph/frozen.pb", "wb") as f:
    f.write(frozen_graph)