如何在模型A中存储训练的权重和偏差,用于张量流中的模型B.

时间:2016-02-18 16:25:06

标签: tensorflow

假设我构建了一个带有权重和偏差变量字典的神经网络A,并且我在神经网络A中训练后得到了权重和偏差的特定值。我想用这些特定值来代替{{ 1}}在神经网络中B.我怎样才能达到这个目的?我已经尝试了tf.Variable,但是,我不知道如何恢复另一个网络B(在另一个文件中)的权重和偏差。我还尝试使用tf.train.Saver()存储它们,但是,我遇到了另一个问题,即在使用pickle.dump恢复weightsbiases时,它说pickle.load类型不耐用。谁能帮我解决这个问题呢?

1 个答案:

答案 0 :(得分:1)

tf.train.Saver类应该对此有所帮助,尽管您可能需要使用一些可选参数来使其工作。

我们假设您的模型A看起来像这样,并且您已对其进行了培训并将其保存到名为"/tmp/model_a_ckpt"的文件中:

weights_a = tf.Variable(..., name="weights_a")
biases_a = tf.Variable(..., name="biases_a")
# ...

saver_a = tf.train.Saver()
# ...
saver_a.save(sess, "/tmp/model_a_ckpt")

...然后让我们说你的模特B看起来像这样:

weights_b = tf.Variable(..., name="weights_b")
biases_b = tf.Variable(..., name="biases_b")

要将检查点加载到模型B中,您必须创建一个保护程序映射检查点中的变量名称(即"weights_a""biases_a",因为它们默认为模型A中相应name个对象的tf.Variable属性到模型B中的变量:

saver_b = tf.train.Saver({"weights_a": weights_b, "biases_a": biases_b})
# ...
saver_b.restore(sess, "/tmp/model_a_ckpt")

运行saver_b.restore()后,模型B中的变量将具有在模型A中训练的值。