tf.train.Saver
如何工作?
我了解到Tensorflow
图定义了计算,并且会话允许执行图或部分图。我还了解到Tensorflow变量仅在会话内有效。因此,我必须通过tf.train.Saver
将模型保存在会话中。
我天真的想法是,如果我在会话中运行部分图并使用tf.train.Saver
保存,那么我将只能保存在会话中运行的图和变量的一部分。
但是,在运行下面的玩具示例之后,无论我运行图形的哪一部分,所有变量都将被保存。
为什么会这样?
folders = ['a', 'b', 'c']
a = tf.get_variable('a', [3], initializer = tf.constant_initializer([1, 1, 1]))
b = tf.add(a, 3, name = 'b')
c = tf.add(b, 4, name = 'c')
saver = tf.train.Saver()
init = tf.global_variables_initializer()
print()
with tf.Session() as sess:
sess.run(init)
print(sess.run(a))
saver.save(sess, "only_a/model.ckpt")
with tf.Session() as sess:
sess.run(init)
print(sess.run(b))
saver.save(sess, "only_b/model.ckpt")
with tf.Session() as sess:
sess.run(init)
print(sess.run(c))
saver.save(sess, "only_c/model.ckpt")
print()
for item in folders:
with tf.Session() as sess:
graph = tf.train.import_meta_graph(f"only_{item}/model.ckpt.meta")
graph.restore(sess, f"only_{item}/model.ckpt")
print(f"The graph {item} is loaded.")
for inner in folders:
try:
test = tf.get_default_graph().get_tensor_by_name(f"{inner}:0")
print(f"get tensor {inner}")
print(f"{inner} ==> {sess.run(test)}")
except Exception as e:
print(f"Failed to get tensor {inner}")
print(e)
print()