What is the real purpose of SavedModelBuilder.add_meta_graph in TensorFlow?

Asked: 2017-10-05 09:45:53

Tags: tensorflow

I've been playing with the saved_model API for a while, until I noticed the duality between two SavedModelBuilder functions: add_meta_graph_and_variables and add_meta_graph.

Since those API names seem to imply that the first function saves everything while the second saves only the graph, I wrongly believed I could extract a sub-graph for the second call to reduce the size of the saved_model.pb file.

But in fact, even though the variables keep the same names, the meta graph loses the ability to link to the weight data.

So far, it seems to me that it is only useful for adding extra tags to the same graph, which looks pointless since you can just pass a list of tags directly.
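To make that concrete, here is a minimal sketch of the two idioms I am comparing (the /tmp paths and the dummy variable are just placeholders): idiom A attaches one meta graph carrying both tags, idiom B attaches a second meta graph that reuses the variables of the first. The only difference I can see is that idiom B lets each set of tags carry its own signature_def_map.

import tensorflow as tf

v = tf.get_variable('v', shape=[1])  # any graph with variables

# Idiom A: a single meta graph carrying both tags
builder_a = tf.saved_model.builder.SavedModelBuilder('/tmp/model_a')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    builder_a.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.TRAINING,
         tf.saved_model.tag_constants.SERVING])
builder_a.save()

# Idiom B: two meta graphs, the second reusing the variables of the first
builder_b = tf.saved_model.builder.SavedModelBuilder('/tmp/model_b')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    builder_b.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.TRAINING])
    builder_b.add_meta_graph([tf.saved_model.tag_constants.SERVING])
builder_b.save()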

I'm sorry to say I fail to see any interesting property of this add_meta_graph function. Can somebody enlighten me?

See the following example:

import os, time

import tensorflow as tf
import numpy as np

# nb_features and nb_classes are not defined in the original snippet; example
# values are assumed here (the inference call further below feeds 6 features)
nb_features = 6
nb_classes = 3

dir = os.path.dirname(os.path.realpath(__file__))
export_dir = dir + '/results/' + str(int(time.time()))
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# We build our neural network and its training graph
with tf.variable_scope('placeholders'):
    x_plh = tf.placeholder(tf.float32, shape=[None, nb_features], name="x")
    y_plh = tf.placeholder(tf.int32, shape=[None, 1], name="y")

with tf.variable_scope('linear_NN'):
    W = tf.get_variable('W', dtype=tf.float32, shape=[nb_features, nb_classes], initializer=tf.random_normal_initializer(0.05))
    y_hat = tf.matmul(x_plh, W)

with tf.variable_scope('loss'):
    loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(y_plh, y_hat))

with tf.variable_scope('predictions'):
    preds = tf.cast(tf.argmax(tf.nn.softmax(y_hat), 1), tf.int32, name="preds")
    accuracies = tf.cast(tf.equal(preds, tf.squeeze(y_plh, 1)), tf.float32)
    accuracy = tf.reduce_mean(accuracies, name="accuracy")

with tf.variable_scope('optimiser'):
    global_step_t = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
    adam = tf.train.AdamOptimizer(1e-2)
    train_op = adam.minimize(loss, global_step=global_step_t)

# We train our model
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    ...
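    # (training loop elided; assume it also produces the validation arrays
    # val_x and val_y used further below)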

    # We add the graph and its variables to the saved_model
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING])


# Let's clean the graph to have only needed inference nodes
serve_graph_def = tf.graph_util.extract_sub_graph(
    tf.get_default_graph().as_graph_def(), 
    ['predictions/preds']
)
tf.reset_default_graph()
tf.import_graph_def(serve_graph_def, name="")
# Another problem here: this function doesn't perform any useful check on the
# variable data; just because I called the first one, I can now call this one.
builder.add_meta_graph(
    [tf.saved_model.tag_constants.SERVING]
    , signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.saved_model.signature_def_utils.predict_signature_def(
                inputs={'x': x_plh}
                , outputs={'out': preds}
            )       
    }
)
builder.save(as_text=True)

# We use a temporary graph to load our saved model
# Everything is working fine here
with tf.Session(graph=tf.Graph()) as sess: 
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
    g = tf.get_default_graph()
    x_plh = g.get_tensor_by_name("placeholders/x:0")
    y_plh = g.get_tensor_by_name("placeholders/y:0")
    accuracy = g.get_tensor_by_name("predictions/accuracy:0")
    acc = sess.run(accuracy, feed_dict={
        x_plh: val_x,
        y_plh: val_y
    })
    print("acc: %f" % acc)

# Now I want to load the simplified graph for inference, but of course
# the link to variables is missing (no more trainable_variables and variables collections)
# So we can't use it like that
# But then, what is the purpose of this add_meta_graph function??
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    g = tf.get_default_graph()

    x_plh = g.get_tensor_by_name("placeholders/x:0")
    preds = g.get_tensor_by_name("predictions/preds:0")
    p = sess.run(preds, feed_dict={ x_plh: [[.1, .1, .1, .1, .1, .1]] })
    print("p: %f" % p)

1 Answer:

Answer 0 (score: 0)

In my code experiments, when you use the extract_sub_graph function and re-import the result with tf.import_graph_def, the node names in the sub-graph change and no longer match the previously saved graph. E.g. the node 'linear_NN/W' becomes 'import/linear_NN/W' in the sub-graph. The sub-graph therefore cannot link to the variables, because the names have changed.
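For reference, a minimal sketch of that renaming (TF 1.x): tf.import_graph_def prefixes every imported node with its name argument, which defaults to 'import'; passing name='' keeps the original names, as the code in the question does.

import tensorflow as tf

with tf.Graph().as_default() as g:
    tf.constant(0.0, name='linear_NN/W')
    graph_def = g.as_graph_def()

with tf.Graph().as_default() as g2:
    tf.import_graph_def(graph_def)  # default name='import' prefixes every node
    print([n.name for n in g2.as_graph_def().node])  # ['import/linear_NN/W']

with tf.Graph().as_default() as g3:
    tf.import_graph_def(graph_def, name='')  # the original names are kept
    print([n.name for n in g3.as_graph_def().node])  # ['linear_NN/W']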