如何在Tensorflow r12中按文件名恢复模型?

时间:2016-12-08 21:08:13

标签: tensorflow

我运行了分布式mnist示例: https://github.com/tensorflow/tensorflow/blob/r0.12/tensorflow/tools/dist_test/python/mnist_replica.py

虽然我已经设置了

saver = tf.train.Saver(max_to_keep=0)

在之前的版本中,与r11一样,我能够遍历每个检查点模型并评估模型的精度。这给了我一个精确度与全局步骤(或迭代)进度的图表。

在r12之前,张量流检查点模型保存在两个文件model.ckpt-1234model-ckpt-1234.meta中。可以通过传递model.ckpt-1234文件名来恢复模型,如saver.restore(sess,'model.ckpt-1234')

但是,我注意到在r12中,现在有三个输出文件model.ckpt-1234.data-00000-of-000001model.ckpt-1234.indexmodel.ckpt-1234.meta

我看到恢复文档说明应该给出/train/path/model.ckpt之类的路径来恢复而不是文件名。有没有办法一次加载一个检查点文件来评估它?我尝试传递model.ckpt-1234.data-00000-of-000001model.ckpt-1234.indexmodel.ckpt-1234.meta文件,但收到如下错误:

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open logdir/2016-12-08-13-54/model.ckpt-0.data-00000-of-00001: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

NotFoundError (see above for traceback): Tensor name "hid_b" not found in checkpoint files logdir/2016-12-08-13-54/model.ckpt-0.index [[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open logdir/2016-12-08-13-54/model.ckpt-0.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

我在OSX Sierra上运行,并通过pip安装了tensorflow r12。

任何指导都会有所帮助。

谢谢。

6 个答案:

答案 0 :(得分:8)

我也使用了Tensorlfow r0.12,我认为保存和恢复模型没有任何问题。以下是一个简单的代码,您可以尝试:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.

  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model

虽然在r0.12中,检查点存储在多个文件中,但您可以使用公共前缀恢复它,即' model.ckpt'在你的情况下。

答案 1 :(得分:5)

R12改变了检查点格式。您应该以旧格式保存模型。

import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
...
saver = tf.train.Saver(write_version = saver_pb2.SaverDef.V1)
saver.save(sess, './model.ckpt', global_step = step)

根据TensorFlow v0.12.0 RC0’s release note

  

新的检查点格式成为tf.train.Saver中的默认格式。老V1   检查点继续可读;由write_version控制   参数,tf.train.Saver现在默认在新V2中写出   格式。它显着降低了所需的峰值内存和延迟   在恢复期间发生。

请参阅my blog中的详细信息。

答案 2 :(得分:3)

您可以像这样恢复模型:

saver = tf.train.import_meta_graph('./src/models/20170512-110547/model-20170512-110547.meta')
            saver.restore(sess,'./src/models/20170512-110547/model-20170512-110547.ckpt-250000'))

路径' / src / models / 20170512-110547 /'包含三个文件:

model-20170512-110547.meta
model-20170512-110547.ckpt-250000.index
model-20170512-110547.ckpt-250000.data-00000-of-00001

如果在一个目录中有多个检查点,例如:路径中有检查点文件 ./20170807-231648 /:

checkpoint     
model-20170807-231648-0.data-00000-of-00001   
model-20170807-231648-0.index    
model-20170807-231648-0.meta   
model-20170807-231648-100000.data-00000-of-00001   
model-20170807-231648-100000.index   
model-20170807-231648-100000.meta

你可以看到有两个检查点,所以你可以使用它:

saver =    tf.train.import_meta_graph('/home/tools/Tools/raoqiang/facenet/models/facenet/20170807-231648/model-20170807-231648-0.meta')

saver.restore(sess,tf.train.latest_checkpoint('/home/tools/Tools/raoqiang/facenet/models/facenet/20170807-231648/'))

答案 3 :(得分:1)

好的,我可以回答我自己的问题。我发现我的python脚本在我的路径中添加了一个额外的'/',所以我正在执行: saver.restore(SESS, '/路径/到/火车// model.ckpt-1234')

以某种方式导致张量流问题。

当我删除它时,请致电: saver.restore(SESS, '/路径/到/特里安/ model.ckpt-1234')

它按预期工作。

答案 4 :(得分:0)

我是TF的新手,遇到了同样的问题。在阅读了Yuan Ma的评论后,我将'.index'与'.data-00000-of-00001'文件一起复制到了同一'train \ ckpt'文件夹中。然后它工作了! 因此,恢复模型时.index文件就足够了。 我在Win7,r12上使用了TF。

答案 5 :(得分:0)

仅使用model.ckpt-1234

至少对我有用