假设我有两个相同的网络,A
和B
。我保存了(使用Saver
的网络A
的先前状态,现在我想将其加载到网络B
中(所有操作均在同一运行中发生)。我该怎么办?
答案 0 :(得分:1)
让我提供一个例子。首先,让我们定义并保存一些变量:
import tensorflow as tf
v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')
saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, './tmp.ckpt')
现在,让我们在新图中定义一些具有相同名称的变量,然后从检查点加载它们的值:
with tf.Graph().as_default():
assert len(tf.trainable_variables()) == 0
v1 = tf.Variable(tf.zeros(1), name='v1')
v2 = tf.Variable(tf.zeros(1), name='v2')
saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, './tmp.ckpt')
print(sess.run([v1, v2]))
最后一行打印:
[array([1.], dtype=float32), array([2.], dtype=float32)]