Tensorflow:如何保存变量并将其加载到其他变量?

时间:2018-11-03 08:04:25

标签: python tensorflow

假设我有两个相同的网络,AB。我保存了(使用Saver的网络A的先前状态,现在我想将其加载到网络B中(所有操作均在同一运行中发生)。我该怎么办?

1 个答案:

答案 0 :(得分:1)

让我提供一个例子。首先,让我们定义并保存一些变量:

import tensorflow as tf

v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmp.ckpt')

现在,让我们在新图中定义一些具有相同名称的变量,然后从检查点加载它们的值:

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros(1), name='v1')
    v2 = tf.Variable(tf.zeros(1), name='v2')

    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './tmp.ckpt')
        print(sess.run([v1, v2]))

最后一行打印:

[array([1.], dtype=float32), array([2.], dtype=float32)]