在检查点Tensorflow中找不到键<variable_name>

时间:2017-07-19 01:28:34

标签: python tensorflow

我正在使用Tensorflow v1.1并且我一直试图弄清楚如何使用我的EMA权重进行推理,但无论我做什么我都会不断收到错误

  

未找到:在检查点中找不到Key W / ExponentialMovingAverage

即使我循环并打印出所有tf.global_variables密钥存在

这是一个可重现的脚本,大量改编自Facenet's单元测试:

import tensorflow as tf
import numpy as np


tf.reset_default_graph()

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
opt_op = optimizer.minimize(loss)

# Track the moving averages of all trainable variables.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
variables = tf.trainable_variables()
print(variables)
averages_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([opt_op]):
    train_op = tf.group(averages_op)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

saver = tf.train.Saver(tf.trainable_variables())

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for _ in range(201):
    sess.run(train_op)

w_reference = sess.run('W/ExponentialMovingAverage:0')
b_reference = sess.run('b/ExponentialMovingAverage:0')

saver.save(sess, os.path.join("model_ex1"))

tf.reset_default_graph()

tf.train.import_meta_graph("model_ex1.meta")
sess = tf.Session()

print('------------------------------------------------------')
for var in tf.global_variables():
    print('all variables: ' + var.op.name)
for var in tf.trainable_variables():
    print('normal variable: ' + var.op.name)
for var in tf.moving_average_variables():
    print('ema variable: ' + var.op.name)
print('------------------------------------------------------')

mode = 1
restore_vars = {}
if mode == 0:
    ema = tf.train.ExponentialMovingAverage(1.0)
    for var in tf.trainable_variables():
        print('%s: %s' % (ema.average_name(var), var.op.name))
        restore_vars[ema.average_name(var)] = var
elif mode == 1:
    for var in tf.trainable_variables():
        ema_name = var.op.name + '/ExponentialMovingAverage'
        print('%s: %s' % (ema_name, var.op.name))
        restore_vars[ema_name] = var

saver = tf.train.Saver(restore_vars, name='ema_restore')

saver.restore(sess, os.path.join("model_ex1")) # error happens here!

w_restored = sess.run('W:0')
b_restored = sess.run('b:0')

print(w_reference)
print(w_restored)
print(b_reference)
print(b_restored)

2 个答案:

答案 0 :(得分:8)

key not found in checkpoint错误表示变量存在于模型的内存中,但不存在于磁盘上的序列化检查点文件中。

您应该使用inspect_checkpoint tool来了解检查点中保存的张量,以及为什么没有保存某些指数移动平均值。

从你的repro示例中不清楚哪一行应该触发错误

答案 1 :(得分:5)

我想添加一种方法,最好在检查点使用训练过的变量。

请记住,保护程序var_list中的所有变量都应包含在您配置的检查点中。您可以通过以下方式检查保存程序:

print(restore_vars)

以及检查点中的那些变量:

vars_in_checkpoint = tf.train.list_variables(os.path.join("model_ex1"))

在你的情况下。

如果restore_vars全部包含在vars_in_checkpoint中,那么它不会引发错误,否则首先初始化所有变量:

all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
sess.run(tf.variables_initializer(all_variables))

所有变量都将初始化为检查点中或不在检查点中的那些变量,然后您可以过滤掉restore_vars中未包含在检查点中的那些变量(假设其名称中的ExponentialMovingAverage的所有变量都不在检查点中):< / p>

temp_saver = tf.train.Saver(
    var_list=[v for v in all_variables if "ExponentialMovingAverage" not in v.name])
ckpt_state = tf.train.get_checkpoint_state(os.path.join("model_ex1"), lastest_filename)
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)

与从头开始训练模型相比,这可以节省一些时间。 (在我的场景中,恢复变量与开始时从头开始训练相比没有显着改善,因为所有旧的优化器变量都被放弃了。但是我认为它可以显着加速优化过程,因为它就像预先训练一些变量一样)

无论如何,有些变量可以像嵌入和一些图层等一样进行恢复。