我想保存Tensorflow(0.12.0)模型,包括图形和变量值,然后加载并执行它。我已经阅读了文档和其他帖子,但无法使基础工作。我正在使用this page in the Tensorflow docs中的技术。代码:
保存一个简单的模型:
myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print sess.run(myVar)
saver0 = tf.train.Saver()
saver0.save(sess, './myModel.ckpt')
saver0.export_meta_graph('./myModel.meta')
稍后,加载并执行模型:
with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.meta')
saver1.restore(sess, './myModel.meta')
print sess.run(myVar)
问题1:保存代码似乎有效,但加载代码会产生此错误:
W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
如何解决此问题?。
问题2:我将此行包含在TF文档中的模式中...
tf.add_to_collection('modelVariables', myVar)
......但为什么这条线必要?默认情况下expert_meta_graph
不会导出整个图表吗?如果没有,那么在保存之前是否需要将图中的每个变量添加到集合中?或者我们只是将那些将在恢复后访问的变量添加到集合中?
---------------------- 2017年1月12日更新 -------------- ---------------
部分成功基于Kashyap的建议,但仍然存在一个谜。只有当我包含包含tf.add_to_collection
和tf.get_collection
的行时,下面的代码才能但。如果没有这些行,'load'模式会在最后一行引发错误:
NameError: name 'myVar' is not defined
。我的理解是默认情况下Saver.save
保存并恢复图中的所有变量,那么为什么需要指定将在集合中使用的变量的名称?我认为这与将Tensorflow的变量名称映射到Python名称有关,但这里的游戏规则是什么?对于哪些变量需要这样做?
mode = 'load' # or 'save'
if mode == 'save':
myVar = tf.Variable(7.1)
init_op = tf.global_variables_initializer()
saver0 = tf.train.Saver()
tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
with tf.Session() as sess:
sess.run(init_op)
print sess.run(myVar)
saver0.save(sess, './myModel')
if mode == 'load':
with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.meta')
saver1.restore(sess, tf.train.latest_checkpoint('./'))
myVar = tf.get_collection('myVar')[0] ### WHY NECESSARY?
print sess.run(myVar)
答案 0 :(得分:1)
<强> 问题1 强>
此问题已经彻底解答here。您不必明确致电export_meta_graph
。致电save method
。这也将生成.meta
文件(因为save方法将在内部调用export_meta_graph
方法。)
例如
saver0.save(sess, './myModel.ckpt')
将生成myModel.ckpt
文件以及myModel.ckpt.meta
文件。
然后您可以使用
恢复模型with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.ckpt.meta')
saver1.restore(sess, './myModel')
print sess.run(myVar)
<强> 问题2 强>
集合用于存储自定义信息,例如学习率,您使用的正则化因子以及其他信息,这些信息将在您导出图形时存储。 Tensorflow本身定义了一些集合,例如&#34; TRAINABLE_VARIABLES&#34;用于获取您构建的模型的所有可训练变量。您可以选择导出图表中的所有集合,也可以指定要在export_meta_graph
函数中导出的集合。
是tensorflow将导出您定义的所有变量。但是,如果您需要任何其他需要导出到图表的信息,那么可以将它们添加到集合中。
答案 1 :(得分:1)
我一直试图弄清楚同样的事情,并且能够使用Supervisor
成功地做到这一点。它会自动加载所有变量和图形等。以下是文档 - https://www.tensorflow.org/programmers_guide/supervisor。以下是我的代码 -
sv = tf.train.Supervisor(logdir="/checkpoint', save_model_secs=60)
with sv.managed_session() as sess:
if not sv.should_stop():
#Do run/eval/train ops on sess as needed. Above works for both saving and loading
如您所见,这比使用Saver
对象和处理单个变量等简单得多,只要图表保持不变(我的理解是Saver
在我们想要时很方便为不同的图表重用预先训练的模型。)