TensorFlow: saving/restoring session, checkpoint, metagraph

Date: 2017-03-16 10:56:45

Tags: python tensorflow neural-network

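I have been trying to restore a model in TensorFlow, but I keep running into problems when I import the metagraph.

This is the code I use to import the metagraph: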

#Create a clean graph and import MetaGraphDef nodes
new_graph = tf.Graph()
with tf.Session(graph=new_graph) as sess:
    # Import the previously exported metagraph
    saver = tf.train.import_meta_graph('/tmp/saver-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

In my Model class, I define the placeholders and the collection as follows:

    """Place Holders"""
    self.input = tf.placeholder(tf.float32, [None, sl], name = 'input')
    self.labels = tf.placeholder(tf.int64, [None], name = 'labels')
    self.keep_prob = tf.placeholder("float", name= 'Drop_out_keep_prob')
    tf.add_to_collection('vars', self.input)
    tf.add_to_collection('vars', self.labels)
    tf.add_to_collection('vars', self.keep_prob)

I train my model as follows:

saver = tf.train.Saver(tf.global_variables())
# Session time
sess = tf.Session() # without context manager, close the session later.
writer = tf.summary.FileWriter("/tmp/model/log_tb", sess.graph) # Writer for tensorboard
sess.run(model.init_op)
  

self.init_op = tf.global_variables_initializer()

I export using these three different options, including the undocumented export_scoped_meta_graph:

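from tensorflow.python.framework import meta_graph  # provides export_scoped_meta_graph

# Export the metagraph three ways: /tmp/scoped.meta, /tmp/my-model.meta,
# and /tmp/saver-model.meta (written by saver.save alongside the checkpoint).
scoped_meta = meta_graph.export_scoped_meta_graph(filename='/tmp/scoped.meta')
meta_graph_def = tf.train.export_meta_graph(filename='/tmp/my-model.meta')
saver.save(sess, '/tmp/saver-model')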

This is the error I get when I try to run it under Windows 10:

E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "BestSplits" device_type: "CPU"') for unknown op: BestSplits
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "CountExtremelyRandomStats" device_type: "CPU"') for unknown op: CountExtremelyRandomStats
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "FinishedNodes" device_type: "CPU"') for unknown op: FinishedNodes
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "GrowTree" device_type: "CPU"') for unknown op: GrowTree
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "ReinterpretStringToFloat" device_type: "CPU"') for unknown op: ReinterpretStringToFloat
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "SampleInputs" device_type: "CPU"') for unknown op: SampleInputs
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "ScatterAddNdim" device_type: "CPU"') for unknown op: ScatterAddNdim
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "TopNInsert" device_type: "CPU"') for unknown op: TopNInsert
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "TopNRemove" device_type: "CPU"') for unknown op: TopNRemove
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "TreePredictions" device_type: "CPU"') for unknown op: TreePredictions
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel ('op: "UpdateFertileSlots" device_type: "CPU"') for unknown op: UpdateFertileSlots
TypeError: expected bytes, NoneType found

During handling of the above exception, another exception occurred:


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: expected bytes, NoneType found

During handling of the above exception, another exception occurred:

SystemError                               Traceback (most recent call last)
<ipython-input-37-60792895b01c> in <module>()
      6     #saver = tf.train.import_meta_graph('/tmp/saver-model.meta')
      7     saver = tf.train.import_meta_graph('/tmp/my-model.meta')
----> 8     saver.restore(sess, tf.train.latest_checkpoint('./'))

c:\users\carlos\anaconda3\lib\site-packages\tensorflow\python\training\saver.py in restore(self, sess, save_path)
   1437       return
   1438     sess.run(self.saver_def.restore_op_name,
-> 1439              {self.saver_def.filename_tensor_name: save_path})
   1440 
   1441   @staticmethod

c:\users\carlos\anaconda3\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

c:\users\carlos\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    963     if final_fetches or final_targets:
    964       results = self._do_run(handle, final_targets, final_fetches,
--> 965                              feed_dict_string, options, run_metadata)
    966     else:
    967       results = []

c:\users\carlos\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1013     if handle is None:
   1014       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1015                            target_list, options, run_metadata)
   1016     else:
   1017       return self._do_call(_prun_fn, self._session, handle, feed_dict,

c:\users\carlos\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1020   def _do_call(self, fn, *args):
   1021     try:
-> 1022       return fn(*args)
   1023     except errors.OpError as e:
   1024       message = compat.as_text(e.message)

c:\users\carlos\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1002         return tf_session.TF_Run(session, options,
   1003                                  feed_dict, fetch_list, target_list,
-> 1004                                  status, run_metadata)
   1005 
   1006     def _prun_fn(session, handle, feed_dict, fetch_list):

SystemError: <built-in function TF_Run> returned a result with an error set

When I try to run it under Debian:

I tensorflow/core/common_runtime/gpu/gpu_device.cc:906] DMA: 0 1
I tensorflow/core/common_runtime/gpu/gpu_device.cc:916] 0:   Y Y
I tensorflow/core/common_runtime/gpu/gpu_device.cc:916] 1:   Y Y
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX TITAN X, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:1) -> (device: 1, name: GeForce GTX TITAN X, pci bus id: 0000:02:00.0)
Traceback (most recent call last):
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 1022, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 1004, in _run_fn
    status, run_metadata)
  File "/usr/lib/python3.4/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InternalError: Unable to get element from the feed as bytes.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/training/saver.py", line 1439, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Unable to get element from the feed as bytes.

1 answer:

Answer 0 (score: 3)

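I managed to fix the problem and decided to share the solution in case anyone runs into this in the future:

Add all placeholders to a collection: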

tf.add_to_collection('vars', input)
tf.add_to_collection('vars', labels)
tf.add_to_collection('vars', keep_prob)
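
Once the placeholders are in a collection, they can be fetched again after the metagraph has been imported into a fresh graph. The following is a minimal sketch, assuming a TensorFlow release that ships tensorflow.compat.v1 (on the 1.x releases this post was written against, plain `import tensorflow as tf` would be used and the `disable_v2_behavior()` call dropped). The scratch directory, the variable `w`, and the shapes are hypothetical stand-ins, not the original model:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: a release that provides compat.v1

tf.disable_v2_behavior()

save_dir = tempfile.mkdtemp()  # hypothetical scratch directory

# Build and save a tiny graph whose placeholders are stored in the 'vars' collection.
build_graph = tf.Graph()
with build_graph.as_default():
    x = tf.placeholder(tf.float32, [None, 3], name='input')
    keep_prob = tf.placeholder(tf.float32, name='Drop_out_keep_prob')
    w = tf.Variable(tf.ones([3, 1]), name='w')  # hypothetical stand-in for model weights
    tf.add_to_collection('vars', x)
    tf.add_to_collection('vars', keep_prob)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() also writes <prefix>.meta by default and returns the prefix.
        prefix = saver.save(sess, os.path.join(save_dir, 'model.ckpt'))

# Import the metagraph into a fresh graph and read the placeholders back.
restored_graph = tf.Graph()
with restored_graph.as_default():
    tf.train.import_meta_graph(prefix + '.meta')
    restored_names = [t.name for t in tf.get_collection('vars')]
```

The tensors returned by `tf.get_collection('vars')` belong to the restored graph, so they can be used as `feed_dict` keys without rebuilding the Model class.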

Merge the summaries and initialize the variables independently (avoid tf.global_variables_initializer()):

merged = tf.summary.merge([loss_summ, cost_summ, tloss_summ, acc_summ])

Save the model during training:

if i%100 == 0:
    saver.save(sess, save_dir + 'model.ckpt', global_step=i+100)

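Initialize a new metagraph, including a saver, before importing the metagraph into the new graph session.

This prevents the saver.saver_def.filename_tensor_name error:

The name 'save/Const:0' refers to a Tensor which does not exist

which occurs because: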

* The default name scope for a tf.train.Saver is "save/" and the placeholder  
 is actually a tf.constant() whose name defaults to "Const:0", which explains  
 why the flag defaults to "save/Const:0".



saver = tf.train.Saver()
sess = tf.Session()
sess.run(init_op)
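
The filename tensor described above can be inspected directly on a saver. This is a small sketch, assuming a TensorFlow release that ships tensorflow.compat.v1 (on older 1.x installs, plain `import tensorflow as tf` without `disable_v2_behavior()`). The variable `v` is hypothetical and is only there because a Saver refuses to be built for a graph with no variables:

```python
import tensorflow.compat.v1 as tf  # assumption: a release that provides compat.v1

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    v = tf.Variable(0.0, name='v')  # hypothetical variable so the Saver has something to save
    saver = tf.train.Saver()

    # Name of the tensor that restore() feeds the checkpoint path into.
    filename_tensor_name = saver.saver_def.filename_tensor_name

    # Look the tensor up by name; this raises KeyError if the graph has no such tensor.
    filename_tensor = graph.get_tensor_by_name(filename_tensor_name)
```

If the same lookup is made against a graph that was built without a saver, the tensor is absent, which appears to be the situation behind the error message quoted above.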

Get the checkpoint with tf.train.get_checkpoint_state():

sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(save_dir)
saver.restore(sess, ckpt.model_checkpoint_path)
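
Both tf.train.get_checkpoint_state() and tf.train.latest_checkpoint() return None when the directory holds no checkpoint. In the question the model was saved under /tmp but looked up in './', so a None path reaching saver.restore() is a plausible source of the "expected bytes, NoneType found" error, although that is an inference rather than something the traceback proves. A guarded restore might look like the sketch below, assuming a TensorFlow release that ships tensorflow.compat.v1. The helper `restore_latest`, the variable `counter`, and the temporary directories are hypothetical:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: a release that provides compat.v1

tf.disable_v2_behavior()


def restore_latest(sess, saver, checkpoint_dir):
    """Restore the newest checkpoint in checkpoint_dir.

    Returns the checkpoint path, or None if the directory has no checkpoint
    (hypothetical helper, not part of TensorFlow).
    """
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        return None  # nothing to restore; never pass None to saver.restore()
    saver.restore(sess, ckpt.model_checkpoint_path)
    return ckpt.model_checkpoint_path


save_dir = tempfile.mkdtemp()   # directory that will hold a checkpoint
empty_dir = tempfile.mkdtemp()  # directory with no checkpoint at all

graph = tf.Graph()
with graph.as_default():
    counter = tf.Variable(7.0, name='counter')  # hypothetical model state
    saver = tf.train.Saver()

    # Save one checkpoint; global_step is appended to the file prefix.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=100)

    # Restore in a new session without running the initializer.
    with tf.Session() as sess:
        missing = restore_latest(sess, saver, empty_dir)
        found = restore_latest(sess, saver, save_dir)
        restored_value = sess.run(counter)
```

Returning None instead of raising keeps the decision with the caller, who can fall back to running the initializer when no checkpoint exists yet.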