I have a variable v and want to apply a moving average to it. I used the following steps to save it:
import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    saver.save(sess, 'C:/Users/User/PycharmProjects/Neural_Network.model.ckpt')
    sess.run([v, ema.average(v)])
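For reference, here is the arithmetic behind the last `sess.run` call above, written as plain Python rather than TensorFlow. Per the TensorFlow documentation, the shadow variable starts at v's initial value and each run of `maintain_averages_op` computes `shadow = decay * shadow + (1 - decay) * value`. No `num_updates` is passed, so the decay stays fixed at 0.99. The function name `ema_update` is a hypothetical helper for illustration only.

```python
def ema_update(shadow, value, decay=0.99):
    """One moving-average step: shadow = decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1 - decay) * value


# The shadow variable starts at v's initial value (0.0); v has been set to 10.0.
shadow = ema_update(0.0, 10.0)
print(shadow)  # approximately 0.1
```

So `sess.run([v, ema.average(v)])` should give roughly `[10.0, 0.1]`, and 0.1 is the value stored under `v/ExponentialMovingAverage` in the checkpoint.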
After saving the session, I want to restore it and use variables_to_restore to load the value of 'v/ExponentialMovingAverage' directly into v.
Here is the code:
v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, 'C:/Users/User/PycharmProjects/Neural_Network.model.ckpt')
    sess.run(v)
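For comparison, here is a minimal sketch of what `variables_to_restore()` returns when the restore graph is built from scratch. It assumes TF 1.15 or 2.x accessed through `tf.compat.v1`, and `fresh_restore_mapping` is a hypothetical helper name. In an empty graph that contains only the new v, the map should have a single entry that sends the checkpoint key `v/ExponentialMovingAverage` to v.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in the TF 1.x code above


def fresh_restore_mapping():
    """Build the restore graph from scratch; return EMA's name map and v."""
    tf1.reset_default_graph()  # no leftover v, v_1, ... from earlier runs
    v = tf1.Variable(0, dtype=tf.float32, name="v")
    ema = tf1.train.ExponentialMovingAverage(0.99)
    # Keys are checkpoint names; values are the variables they are loaded into.
    return ema.variables_to_restore(), v


mapping, v = fresh_restore_mapping()
print(list(mapping))  # ['v/ExponentialMovingAverage']
```

If the printed map has extra keys like `v_1/...` or `..._1`, the graph already held variables from an earlier run.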
However, it raises a NotFoundError:
NotFoundError (see above for traceback): Key v/ExponentialMovingAverage/ExponentialMovingAverage_1 not found in checkpoint
I'm a bit confused by the output of print(ema.variables_to_restore()):
{'v/ExponentialMovingAverage/ExponentialMovingAverage_1': <tf.Variable 'v/ExponentialMovingAverage:0' shape=() dtype=float32_ref>, 'v_1/ExponentialMovingAverage_1': <tf.Variable 'v_1:0' shape=() dtype=float32_ref>, 'v_3/ExponentialMovingAverage': <tf.Variable 'v_3:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage_2': <tf.Variable 'v:0' shape=() dtype=float32_ref>, 'v_2/ExponentialMovingAverage_1': <tf.Variable 'v_2:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage_1': <tf.Variable 'v/ExponentialMovingAverage_1:0' shape=() dtype=float32_ref>, 'v_1/ExponentialMovingAverage': <tf.Variable 'v_1/ExponentialMovingAverage:0' shape=() dtype=float32_ref>, 'v/ExponentialMovingAverage/ExponentialMovingAverage': <tf.Variable 'v/ExponentialMovingAverage/ExponentialMovingAverage:0' shape=() dtype=float32_ref>, 'v_2/ExponentialMovingAverage': <tf.Variable 'v_2/ExponentialMovingAverage:0' shape=() dtype=float32_ref>}
Why are there so many variables such as v_1 and v_2? How should variables_to_restore be used?
Answer 0 (score: 0)
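If you call tf.train.ExponentialMovingAverage(0.99) twice in the same program, ExponentialMovingAverage_1 is created. More generally, when the code runs twice in one program (for example, when it is re-run in the same interactive console), everything is added to the same default graph, and TensorFlow keeps names unique by appending a suffix: the second tf.Variable(..., name='v') becomes v_1, and the second moving average gets the suffix _1. Keys such as v/ExponentialMovingAverage/ExponentialMovingAverage_1 then do not exist in the checkpoint, which is what the NotFoundError reports.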
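A minimal sketch of the cause and one possible fix, assuming TF 1.15 or 2.x through `tf.compat.v1`. It first shows the name suffix that appears when a variable named `v` is created twice in one graph. It then saves and restores with `tf.compat.v1.reset_default_graph()` called before each graph is built. The sketch also applies the average to `tf1.trainable_variables()` instead of `tf.global_variables()`, so shadow variables are not averaged themselves; that swap is a suggestion, not part of the original code. The helper names and the temporary checkpoint path are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in the TF 1.x code above


def build_twice_without_reset():
    """Create a variable named 'v' twice in one graph; the second gets a suffix."""
    tf1.reset_default_graph()
    first = tf1.Variable(0, dtype=tf.float32, name="v")
    second = tf1.Variable(0, dtype=tf.float32, name="v")
    return first.op.name, second.op.name  # ('v', 'v_1')


def save_model(path):
    tf1.reset_default_graph()  # empty graph: no v_1 and no _1 suffixes
    v = tf1.Variable(0, dtype=tf.float32, name="v")
    ema = tf1.train.ExponentialMovingAverage(0.99)
    # Average only trainable variables, not the shadow variables themselves.
    maintain_averages_op = ema.apply(tf1.trainable_variables())
    saver = tf1.train.Saver()  # saves 'v' and 'v/ExponentialMovingAverage'
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(tf1.assign(v, 10))
        sess.run(maintain_averages_op)
        saver.save(sess, path)
        return sess.run([v, ema.average(v)])


def restore_average(path):
    tf1.reset_default_graph()  # fresh graph again for the restore step
    v = tf1.Variable(0, dtype=tf.float32, name="v")
    ema = tf1.train.ExponentialMovingAverage(0.99)
    # Loads the checkpoint key 'v/ExponentialMovingAverage' into the local v.
    saver = tf1.train.Saver(ema.variables_to_restore())
    with tf1.Session() as sess:
        saver.restore(sess, path)
        return sess.run(v)


if __name__ == "__main__":
    ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")  # hypothetical path
    print(build_twice_without_reset())
    print(save_model(ckpt))       # v and its moving average
    print(restore_average(ckpt))  # v now holds the restored moving average
```

In a script run once from a fresh process the reset calls are redundant. They matter mainly in interactive sessions where the same cells are executed repeatedly.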