需要保存哪些内容才能在TensorFlow

时间:2017-02-21 06:57:07

标签: machine-learning tensorflow

我已经开始探索TensorFlow库并尝试使用MNIST数据的图像分类example。我希望在训练阶段结束后将模型存储在文件中,以便我可以在需要时使用它。我已经检查了this link,它告诉我们如何将TensorFlow中的值保存到任何文件中并读取它。到目前为止,我能够使用pickle将一些变量从脚本保存到文件中,如链接中所示。但是,我无法掌握文件中需要保存的内容以存储模型的当前状态以供以后使用。 请有人可以通过存储模型和加载该模型的示例来解释该部分。

3 个答案:

答案 0 :(得分:2)

要在Tensorflow中保存和恢复变量,需要执行以下操作。

1)要保存和恢复的变量列表 2)tf.train.Saver

通常,1)通过

实现
# To save and restore whole tf variables
all_vars = tf.global_variables()

,或者

# To save and restore the specific tf variables using scope
all_vars = tf.global_variables()
model_vars = [k for k in all_vars if k.name.startswith("xxx")]
# "xxx" is the expected scope

然后,2)通过

实现
saver = tf.train.Saver(vars_list)
# vars_list is list of variables from above

最后,要保存变量,(使用名为'sess'的tf.Session()运行)

saver.save(sess, '/directory/to/chechpoint/file.ckpt')

并恢复它们,

saver.restore(sess, '/directory/to/chechpoint/file.ckpt')

答案 1 :(得分:1)

只能保存和恢复Variables。当您需要重用已保存的变量时,您需要首先通过创建神经网络并设置NN的参数(如图层编号,学习速率和丢失等)来构建图形。从检查点恢复的唯一值是变量在培训过程中定义。您可以查看任何示例,例如this one

总而言之,只有变量可以并且需要保存和恢复,神经网络配置和placeholders不能。

答案 2 :(得分:0)

首先,您应该查看此other question

TensorFlow实现了用于管理保存和恢复检查点的方法,特别是tf.train.saver类。查看官方文档here。检查点基本上将您的张量值(以及其他内容)存储在磁盘中。

引用文档:

  

检查点是专有格式的二进制文件,它将变量名称映射到张量值。检查检查点内容的最佳方法是使用Saver加载它。