tf.train.Saver如何保存图形和变量?

时间:2019-04-29 01:51:37

标签: python tensorflow

tf.train.Saver如何工作?

我了解到Tensorflow图定义了计算,并且会话允许执行图或部分图。我还了解到Tensorflow变量仅在会话内有效。因此,我必须通过tf.train.Saver将模型保存在会话中。

我天真的想法是,如果我在会话中运行部分图并使用tf.train.Saver保存,那么我将只能保存在会话中运行的图和变量的一部分。

但是,在运行下面的玩具示例之后,无论我运行图形的哪一部分,所有变量都将被保存。

为什么会这样?

folders = ['a', 'b', 'c']
a = tf.get_variable('a', [3], initializer = tf.constant_initializer([1, 1, 1]))
b = tf.add(a, 3, name = 'b')
c = tf.add(b, 4, name = 'c')

saver = tf.train.Saver()
init = tf.global_variables_initializer()

print()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(a))
    saver.save(sess, "only_a/model.ckpt")
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(b))
    saver.save(sess, "only_b/model.ckpt")
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c))
    saver.save(sess, "only_c/model.ckpt")

print()

for item in folders:
    with tf.Session() as sess:
        graph = tf.train.import_meta_graph(f"only_{item}/model.ckpt.meta")
        graph.restore(sess, f"only_{item}/model.ckpt")
        print(f"The graph {item} is loaded.")
        for inner in folders:
            try:
                test = tf.get_default_graph().get_tensor_by_name(f"{inner}:0")
                print(f"get tensor {inner}")
                print(f"{inner} ==> {sess.run(test)}")
            except Exception as e:
                print(f"Failed to get tensor {inner}")
                print(e)
    print()

0 个答案:

没有答案