get_variable()无法识别tf.estimator的现有变量

时间:2018-11-26 11:26:46

标签: tensorflow tensorflow-estimator

这个问题已经问过here,区别是我的问题集中在Estimator上。

某些情况:我们已经使用estimator训练了一个模型,并获得了在Estimator input_fn中定义的一些变量,此函数将数据预处理为批量。现在,我们正在转向预测。在预测期间,我们使用相同的input_fn读取和处理数据。 但是错误提示变量(word_embeddings)不存在(变量存在于chkp图中),这是input_fn中相关的代码:

with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
    if mode == tf.estimator.ModeKeys.TRAIN:
        word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
        word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
                                          trainable=False,
                                          name="word_to_vec",
                                          dtype=tf.float32)
    else:
        word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)

基本上,当它处于预测模式时,将调用else以在检查点中加载变量。无法识别此变量表示a)范围使用不当; b)图形未还原。只要设置reuse,我认为范围就没有那么重要了。

我怀疑这是因为该图尚未在input_fn阶段恢复。通常,通过调用saver.restore(sess, "/tmp/model.ckpt") reference恢复图形。对估算器source code的调查并没有给我带来任何与恢复有关的信息,最好的镜头是MonitoredSession,它是培训的包装。最初的问题已经牵扯到很多东西了,不确定我是否走上了正确的道路,如果有人有任何见解,我会在这里寻求帮助。

我的问题的一行摘要:如何通过tf.estimatorinput_fnmodel_fn内恢复图形?

1 个答案:

答案 0 :(得分:1)

嗨,我认为您的错误之所以来是因为您没有在tf.get_variable中指定形状(预测时),即使您要恢复该变量,似乎也需要指定形状。

我使用简单的线性回归估计量进行了以下测试,该估计量仅需要预测x + 5

def input_fn(mode):
    def _input_fn():
        with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
            if mode == tf.estimator.ModeKeys.TRAIN:
                var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
                x_data = np.random.randn(1000)
                labels = x_data + 5
                return {'x':x_data}, labels
            elif mode == tf.estimator.ModeKeys.PREDICT:
                var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
                return {'x':[0,10,100,var_to_follow]}
    return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')

此代码工作得很好,const的值为20,并且在我的测试集中使用它很有趣以确认:p

但是,如果删除shape = [],它会中断,您还可以提供另一个初始化程序,例如tf.constant(500),一切正常,将使用20。

通过跑步

model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)

preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))

您可以可视化图形,并且会看到a)范围确定是正常的,b)图形已恢复。

希望这会对您有所帮助。