保存并加载Tensorflow模型

时间:2017-01-12 03:27:46

标签: tensorflow

我想保存Tensorflow(0.12.0)模型,包括图形和变量值,然后加载并执行它。我已经阅读了文档和其他帖子,但无法使基础工作。我正在使用this page in the Tensorflow docs中的技术。代码:

保存一个简单的模型:

myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print sess.run(myVar)
    saver0 = tf.train.Saver()
    saver0.save(sess, './myModel.ckpt')
    saver0.export_meta_graph('./myModel.meta')

稍后,加载并执行模型:

with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.meta')
    saver1.restore(sess, './myModel.meta')
    print sess.run(myVar)

问题1:保存代码似乎有效,但加载代码会产生此错误:

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

如何解决此问题?。

问题2:我将此行包含在TF文档中的模式中...

tf.add_to_collection('modelVariables', myVar)

......但为什么这条线必要?默认情况下expert_meta_graph不会导出整个图表吗?如果没有,那么在保存之前是否需要将图中的每个变量添加到集合中?或者我们只是将那些将在恢复后访问的变量添加到集合中?

---------------------- 2017年1月12日更新 -------------- ---------------

部分成功基于Kashyap的建议,但仍然存在一个谜。只有当我包含包含tf.add_to_collectiontf.get_collection的行时,下面的代码才能。如果没有这些行,'load'模式会在最后一行引发错误: NameError: name 'myVar' is not defined。我的理解是默认情况下Saver.save保存并恢复图中的所有变量,那么为什么需要指定将在集合中使用的变量的名称?我认为这与将Tensorflow的变量名称映射到Python名称有关,但这里的游戏规则是什么?对于哪些变量需要这样做?

mode = 'load' # or 'save'
if mode == 'save':
    myVar = tf.Variable(7.1)
    init_op = tf.global_variables_initializer()
    saver0 = tf.train.Saver()
    tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
    with tf.Session() as sess:
        sess.run(init_op)
        print sess.run(myVar)
        saver0.save(sess, './myModel')
if mode == 'load':
    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph('./myModel.meta')
        saver1.restore(sess, tf.train.latest_checkpoint('./'))
        myVar = tf.get_collection('myVar')[0]  ### WHY NECESSARY?
        print sess.run(myVar)

2 个答案:

答案 0 :(得分:1)

<强> 问题1

此问题已经彻底解答here。您不必明确致电export_meta_graph。致电save method。这也将生成.meta文件(因为save方法将在内部调用export_meta_graph方法。)

例如

saver0.save(sess, './myModel.ckpt')

将生成myModel.ckpt文件以及myModel.ckpt.meta文件。

然后您可以使用

恢复模型
with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.ckpt.meta')
    saver1.restore(sess, './myModel')
    print sess.run(myVar)

<强> 问题2

集合用于存储自定义信息,例如学习率,您使用的正则化因子以及其他信息,这些信息将在您导出图形时存储。 Tensorflow本身定义了一些集合,例如&#34; TRAINABLE_VARIABLES&#34;用于获取您构建的模型的所有可训练变量。您可以选择导出图表中的所有集合,也可以指定要在export_meta_graph函数中导出的集合。

是tensorflow将导出您定义的所有变量。但是,如果您需要任何其他需要导出到图表的信息,那么可以将它们添加到集合中。

答案 1 :(得分:1)

我一直试图弄清楚同样的事情,并且能够使用Supervisor成功地做到这一点。它会自动加载所有变量和图形等。以下是文档 - https://www.tensorflow.org/programmers_guide/supervisor。以下是我的代码 -

sv = tf.train.Supervisor(logdir="/checkpoint', save_model_secs=60)
    with sv.managed_session() as sess:
        if not sv.should_stop(): 
            #Do run/eval/train ops on sess as needed. Above works for both saving and loading

如您所见,这比使用Saver对象和处理单个变量等简单得多,只要图表保持不变(我的理解是Saver在我们想要时很方便为不同的图表重用预先训练的模型。)