如何在tensorflow中恢复模型使用的变量子集

时间:2017-05-16 11:54:24

标签: tensorflow restore

我有2个文件,learn.py来保存模型,learn_2.py来恢复模型(这里是tf.variable a)并初始化新的tf.variable b,但出了点问题,这里是错误的:

tensorflow.python.framework.errors_impl.NotFoundError: Key scope/bb not found in checkpoint

learn.py

import tensorflow as tf

with tf.variable_scope("scope"):
    a = tf.get_variable("aa", shape=[2,4])

sess = tf.Session()
#sess.run(tf.global_variables_initializer())
sess.run(tf.initialize_variables([a]))
saver = tf.train.Saver()
save_path = saver.save(sess, "./tmp/model.ckpt")
print "---"
print sess.run(a)

learn_2.py

import tensorflow as tf

with tf.variable_scope("scope"):
    a = tf.get_variable("aa", shape=[2,4])
    b = tf.get_variable("bb", shape=[2,4])
sess = tf.Session()
#sess.run(tf.global_variables_initializer())
#sess.run(tf.initialize_variables([b]))
saver = tf.train.Saver()
save_path = saver.restore(sess, "./tmp/model.ckpt")
sess.run(tf.initialize_variables([b]))
print "---"
print sess.run(a)
print sess.run(b)

1 个答案:

答案 0 :(得分:0)

在第二个脚本learn2.py中执行此操作。适合我!

import tensorflow as tf

with tf.variable_scope("scope"):
    a = tf.get_variable("aa", shape=[2,4])
    sess = tf.Session()
    saver = tf.train.Saver()
    save_path = saver.restore(sess, "/tmp/model.ckpt")
    b = tf.get_variable("bb", shape=[2,4])
    sess.run(tf.initialize_variables([b]))
    print "---"
    print sess.run(a)
    print sess.run(b)

脚本1的输出(与learn.py相同)

---
[[ 0.21811056  0.75089216  0.43180299 -0.36542225]
 [-0.11786985 -0.26542974  0.68785524 -0.57991886]]

脚本2的输出(如上所述)

---
[[ 0.21811056  0.75089216  0.43180299 -0.36542225]
 [-0.11786985 -0.26542974  0.68785524 -0.57991886]]
[[-0.62411451 -0.32599163  0.72495079 -0.09547448]
 [-0.59518242  0.51209545 -0.68833208 -0.03813028]]

说明:

您添加了变量" a"和" b"在第二个脚本中绘图。当您尝试恢复时,它将搜索当前图形中的所有变量(" a"&" b")。我的解决方案是

  1. 构建与您首先要恢复的图形相同的图形。
  2. 恢复图表 - 所有变量
  3. 将节点/图层添加到图形并仅初始化新添加的变量。
  4. 我希望这会有所帮助。