如何在多个模型中使用tensorflow saver?

时间:2016-10-21 01:21:48

标签: python tensorflow

我在理解正确使用tf.train.Saver

方面遇到了很多麻烦

我有一个会话,我创建了几个不同的独立网络模型。所有模型都经过培训,我保存了性能最佳的网络供以后使用。

但是,当我尝试稍后恢复模型时,我收到一个错误,似乎表明某些变量未被保存或恢复:

NotFoundError: Tensor name "Network_8/train/beta2_power" not found in checkpoint files networks/network_0.ckpt

出于某种原因,当我尝试加载Network_0的变量时,我被告知我需要Network_8的变量信息。

确保我只从多网络会话中保存/恢复正确的变量的最佳方法是什么?

似乎我的问题的一部分是,虽然我为每个网络创建了一个我想要保存的变量的dict对象(权重和偏差),当我设置一个优化器,例如AdamOptimizer时,tensorflow自动创建额外的变量,需要初始化。如果您使用tf.train.Saver保存所有变量并且您只有一个网络,这很好,但是我正在训练多个网络并且只保存最佳结果。我不确定如何指定自动添加到我的dict中用于保存的变量。

1 个答案:

答案 0 :(得分:1)

我的解决方案是在原始模型和新模型(即Network_0和Network_8)中创建一个具有相同张量名称的part_saver,它只恢复所需的变量。

part_saver = tf.train.Saver({"W":w,"b":b,...})

在恢复部分模型之前,初始化Network_8中的所有变量。