TensorFlow Eager模式:如何从检查点恢复模型?

时间:2017-12-17 05:39:22

标签: python tensorflow deep-learning

我已经在TensorFlow eager模式中培训了CNN模型。现在我试图从检查点文件中恢复训练有素的模型但是没有取得任何成功。

我发现的所有示例(如下所示)都在讨论将检查点恢复到会话。但我需要的是将模型恢复到急切模式,即不创建会话。

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

基本上我需要的是:

tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)

然后我可以使用该模型进行预测。

有人可以帮忙吗?

更新

示例代码位于:mnist eager mode demo

我已经尝试按照@Jay Shah的回答中的步骤进行操作,但它几乎有效,但恢复的模型中没有任何变量。

tfe.save_network_checkpoint(model,'./test/my_model.ckpt')

Out[58]:
'./test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables

Out[72]:
[]

原始模型中有很多变量。:

model.variables

[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
 array([[[[ -8.25184360e-02,   6.77833706e-03,   6.97569922e-02,...

5 个答案:

答案 0 :(得分:7)

Eager Execution仍然是TensorFlow中的新功能,并未包含在最新版本中,因此并非所有功能都受支持,但幸运的是,从保存的检查点加载模型是。

您需要使用tfe.Saver类(这是tf.train.Saver类的一个瘦包装器),您的代码应如下所示:

saver = tfe.Saver([x, y])
saver.restore('/tmp/ckpt')

其中[x,y]表示要恢复的变量和/或模型列表。这应该与最初创建创建检查点的保护程序时传递的变量完全匹配。

可以找到更多详细信息,包括示例代码here,并且可以找到保护程序的API详细信息here

答案 1 :(得分:3)

好的,在逐行运行代码几个小时之后,我找到了一种方法将检查点恢复到新的TensorFlow Eager Mode模型。

使用TF Eager Mode MNIST

中的示例

步骤:

  1. 训练模型后,从训练过程中创建的检查点文件夹中找到最新的检查点(或所需的检查点)索引文件,例如'ckpt-25800.index'。在步骤5中恢复时,仅使用文件名“ckpt-25800”。

  2. 启动一个新的python终端并通过运行:

    启用TensorFlow Eager模式

    tfe.enable_eager_execution()

  3. 创建MNISTMOdel的新实例:

    model_new = MNISTModel()

  4. 通过运行一次虚拟列车过程来初始化model_new的变量。(这一步很重要。如果不首先初始化变量,它们就无法通过以下步骤恢复。但是我找不到另一种方法在Eager模式下初始化变量,而不是我在下面做的。)

    model_new(tfe.Variable(np.zeros((1,784),dtype=np.float32)), training=True)

  5. 使用步骤1中标识的检查点将变量恢复为model_new。

    tfe.Saver((model_new.variables)).restore('./tf_checkpoints/ckpt-25800')

  6. 如果恢复过程成功,您应该看到类似的内容:

    INFO:tensorflow:Restoring parameters from ./tf_checkpoints/ckpt-25800

  7. 现在检查点已成功恢复到model_new,您可以使用它来预测新数据。

答案 2 :(得分:1)

我想分享TFLearnDeep learning library featuring a higher-level API for TensorFlow。借助此库,您可以轻松save and restore模型。

保存模型

model = tflearn.DNN(net) #Here 'net' is your designed network model. 
#This is a sample example for training the model
model.fit(train_x, train_y, n_epoch=10, validation_set=(test_x, test_y), batch_size=10, show_metric=True)
model.save("model_name.ckpt")

恢复模型

model = tflearn.DNN(net)
model.load("model_name.ckpt")

有关tflearn的更多示例,您可以查看某个网站,例如...

答案 3 :(得分:1)

  • 首先,您可以通过执行以下操作将模型保存在检查点中:

saver.save(sess, './my_model.ckpt')

  • 在上面一行中,您将在“my_model.ckpt”检查点
  • 中保存会话

以下代码恢复模型

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './my_model.ckpt')
  • 当您将会话还原为模型时,您将从ckpt
  • 恢复模型

对于急切的保存模式:

tf.contrib.eager.save_network_checkpoint(sess,'./my_model.ckpt')

要恢复的急切模式:

tf.contrib.eager.restore_network_checkpoint(sess,'./my_model.ckpt')

sess是类Network的对象。可以保存和恢复类网络的任何对象。网络对象的快速解释: -

class TwoLayerNetwork(tfe.Network):
    def __init__(self, name):
        super(TwoLayerNetwork, self).__init__(name=name)
        self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,)))
        self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,)))
    def call(self, inputs):
        return self.layer_two(self.layer_one(inputs))

构造对象并调用Network后,变量列表 由跟踪的Layer创建的Network.variables可通过 sess = TwoLayerNetwork(name="net") # sess is object of Network output = sess(tf.ones([1, 8])) print([v.name for v in sess.variables]) ``` ================================================================= This example prints variable names, one kernel and one bias per `tf.layers.Dense` layer: ['net/dense/kernel:0', 'net/dense/bias:0', 'net/dense_1/kernel:0', 'net/dense_1/bias:0'] These variables can be passed to a `Saver` (`tf.train.Saver`, or `tf.contrib.eager.Saver` when executing eagerly) to save or restore the `Network` ================================================================= ``` tfe.save_network_checkpoint(sess,'./my_model.ckpt') # saving the model tfe.restore_network_checkpoint(sess,'./my_model.ckpt') # restoring 获得: 蟒

{{1}}

答案 4 :(得分:0)

使用tfe.Saver().save()保存变量:

for epoch in range(epochs):
    train_and_optimize()
    all_variables = model.variables + optimizer.variables()

    # save the varibles 
    tfe.Saver(all_variables).save(checkpoint_prefix)

然后使用tfe.Saver().restore()重新加载保存的变量:

tfe.Saver((model.variables + optimizer.variables())).restore(checkpoint_prefix)

然后,将使用已保存的变量加载模型,而无需像@Stefan Falk的答案中那样创建新变量。