从model_dir部分加载tf.contrib.learn.Estimators(在自动编码器设置中仅加载解码器权重)

时间:2017-07-23 17:10:25

标签: tensorflow autoencoder

我正在尝试使用Estimators,而不是自己实施训练循环。我在MNIST数据上玩自动编码器。我有一个training_model_fn函数来构建一个包含输入,模型,损失,优化器和摘要的训练模型。我可以训练它,一切顺利,但是当我试图只加载解码器部分时 - 它失败了。

我希望解码器模型将encoded向量作为输入,并运行网络的相同解码部分(使用先前学习的权重)以在结束时生成decoded图像。

我创建了另一个decoded_model_fn函数,它与训练版共享一些代码并仅创建模型的相关部分,但是当我尝试加载Estimator时:

est = tf.contrib.learn.Estimator(model_fn=decoder_model_fn, model_dir=...)
est.predict(input_fn=...)

我收到以下错误:

...    
NotFoundError: Key ... not found in checkpoint ...
...

我假设Estimator正试图从检查点加载所有变量,显然我的解码器模型并不包含所有变量。

有谁知道如何从存储的会话中部分加载变量?我希望有一个ignore_unknowns标志,但找不到任何类似的东西。

我应该如何将Estimator用于自动编码器模型?

1 个答案:

答案 0 :(得分:0)

好的......回答自己以防将来有人接触到这个:

我的主要问题是我错误地没有将解码器放在与大型模型相同的variable_scope中。变量没有以相同的名称调用,因此无法加载。

发生在我身上的另一件事,可能会导致问题 - 当我训练原始模型时,我用numpy数组喂它,我确保将其转换为float32,因此权重存储为float32。当我提供解码器模型时,我使用了一些虚拟np数组,它们是float64,因此tensorflow抱怨它需要的数据多于检查点数据中的数据。花了我一些时间来弄清楚为什么会这样......