带有检查点的Tensorflow simple_save

时间:2019-01-31 09:51:22

标签: python tensorflow

我正在尝试在训练时以不同的步骤保存模型。假设我想在5个时间段后保存。

此刻我正在使用:

tf.saved_model.simple_save(
            sess, model_folder, inputs, outputs
        )

这是一种魅力。尽管如此,我意识到每次迭代都节省了整个图形和权重,这具有很高的计算成本。

我想更新模型的权重,以使图形与以前的保存保持一致(因为在训练期间它没有变化)

我已经阅读了有关 tf.train.Saver 的信息,这似乎符合我的意图。但这迫使我指定所有要保存的变量,这不像 simple_save 方法那样实用。因此,我想知道是否存在以检查点方式使用 simple_save 的任何方式。

1 个答案:

答案 0 :(得分:1)

我认为您对tf.train.Saver的理解有误。您可以做一些简单的事情:

saver = tf.train.Saver()
with tf.Session() as sess:
    for e in range(epochs):
        ...
        if e % 5 == 0:
            saver.save(sess, "/path/where/to/save/model")

因此无需指定要保存的每个变量。