使用tf.train.save时,无法恢复Adam Optimizer的变量

时间:2017-11-22 10:50:04

标签: machine-learning tensorflow

当我尝试在tensorflow中恢复已保存的模型时出现以下错误:

 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key out_w/Adam_5 not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam_4 not found in checkpoint

我想我无法保存Adam Optimizer的变量。 任何修复?

2 个答案:

答案 0 :(得分:0)

考虑这个小实验:

import tensorflow as tf

def simple_model(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

def simple_model2(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1_x', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1_x', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (None, 5))
model = simple_model(X)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './Checkpoint', global_step = 0)

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (None, 5))
model = simple_model(X)      # Case 1
#model = simple_model2(X)    # Case 2
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('.'))

在案例1中,一切正常。但是在Case2中,你会得到像Key Layer1/b1_x not found in checkpoint这样的错误,这是因为模型中的变量名称是不同的(尽管两个变量的形状和数据类型相同)。确保变量在要还原的模型中具有相同的名称。

要检查检查点中存在的变量的名称,请选中answer

答案 1 :(得分:0)

由于检查点中只有部分可用的adam参数,因此当您不同时训练每个变量时,也会发生这种情况。

一种可能的解决方法是"重置"加载检查点后的亚当。为此,在创建保护程序时过滤与adam相关的变量:

vl = [v for v in tf.global_variables() if "Adam" not in v.name]
saver = tf.train.Saver(var_list=vl)

确保之后初始化全局变量。