TensorFlow 报错“GraphDef 不能大于 2GB”,但模型其实并没有这么大

时间:2017-02-24 14:06:11

标签: tensorflow

在保存模型之后再恢复它时,我收到了这个错误。

我尝试修改了模型结构,已经把模型大小减小到 1.2GB,但错误仍然存在:

-rw-rw-r-- 1 ubuntu ubuntu 1.2G Feb 24 13:44 cnn-classifier-model-0.data-00000-of-00001
-rw-rw-r-- 1 ubuntu ubuntu 1.1K Feb 24 13:44 cnn-classifier-model-0.index
-rw-rw-r-- 1 ubuntu ubuntu 102K Feb 24 13:44 cnn-classifier-model-0.meta

知道为什么会这样吗?

一些代码:

# In the graph
# Create a saver.
# NOTE(review): a Saver built with no arguments presumably covers the variables
# that already exist at this point — confirm it is created after the model.
saver = tf.train.Saver()
# Register the tensor/op we want to fetch after restoring in the named
# collection 'foo', so it can be looked up again with tf.get_collection('foo')
# once the meta graph has been re-imported.
tf.add_to_collection('foo', foo)

# In session
# global_step=0 appends "-0" to the checkpoint prefix, which is why the files
# are named cnn-classifier-model-0.meta / .index / .data-00000-of-00001.
save_path = saver.save(sess, 'metafiles/cnn-classifier-model', global_step=0)    # at the end of training

# Later in a separate Jupyter cell/new notebook
# Restore into a brand-new graph instead of the notebook's default graph.
# tf.train.import_meta_graph() adds the imported nodes to the *current default
# graph*. In a notebook that graph may still contain everything built by
# earlier cells. The session serializes the whole graph the first time it runs
# an op (Session._extend_graph), and that is where "GraphDef cannot be larger
# than 2GB" is raised, even though the .meta file itself is only ~100 KB.
restore_graph = tf.Graph()
with restore_graph.as_default():
    # The context manager closes the session even if restore() raises, which
    # an explicit sess.close() after the calls would not do.
    with tf.Session(graph=restore_graph) as sess:
        new_saver = tf.train.import_meta_graph('metafiles/cnn-classifier-model-0.meta')
        new_saver.restore(sess, 'metafiles/cnn-classifier-model-0')
        # Look the tensor up in the freshly imported graph (the default graph
        # inside this block).
        foo = tf.get_collection('foo')[0]

错误跟踪:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-49-ada5c83f8e85> in <module>()
      4 sess = tf.InteractiveSession()
      5 new_saver = tf.train.import_meta_graph('metafiles/cnn-classifier-model-0.meta')
----> 6 new_saver.restore(sess, 'metafiles/cnn-classifier-model-0')
      7 predictions = tf.get_collection('predictions')[0]
      8 sess.close()

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in restore(self, sess, save_path)
   1437       return
   1438     sess.run(self.saver_def.restore_op_name,
-> 1439              {self.saver_def.filename_tensor_name: save_path})
   1440 
   1441   @staticmethod

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    963     if final_fetches or final_targets:
    964       results = self._do_run(handle, final_targets, final_fetches,
--> 965                              feed_dict_string, options, run_metadata)
    966     else:
    967       results = []

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1013     if handle is None:
   1014       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1015                            target_list, options, run_metadata)
   1016     else:
   1017       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1020   def _do_call(self, fn, *args):
   1021     try:
-> 1022       return fn(*args)
   1023     except errors.OpError as e:
   1024       message = compat.as_text(e.message)

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
    998                 run_metadata):
    999       # Ensure any changes to the graph are reflected in the runtime.
-> 1000       self._extend_graph()
   1001       with errors.raise_exception_on_not_ok_status() as status:
   1002         return tf_session.TF_Run(session, options,

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _extend_graph(self)
   1042         graph_def, self._current_version = self._graph._as_graph_def(
   1043             from_version=self._current_version,
-> 1044             add_shapes=self._add_shapes)
   1045         # pylint: enable=protected-access
   1046 

/home/ubuntu/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in _as_graph_def(self, from_version, add_shapes)
   2218           bytesize += op.node_def.ByteSize()
   2219           if bytesize >= (1 << 31) or bytesize < 0:
-> 2220             raise ValueError("GraphDef cannot be larger than 2GB.")
   2221       if self._functions:
   2222         for f in self._functions.values():

ValueError: GraphDef cannot be larger than 2GB.

0 个答案:

没有答案