tensorflow: ValueError: GraphDef cannot be larger than 2GB

Time: 2017-04-16 22:58:57

Tags: tensorflow neural-network conv-neural-network convolution

This is the error I get:

Traceback (most recent call last):
  File "fully_connected_feed.py", line 387, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/-/.local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "fully_connected_feed.py", line 289, in main
    run_training()
  File "fully_connected_feed.py", line 256, in run_training
    saver.save(sess, checkpoint_file, global_step=step)
  File "/home/-/.local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1386, in save
    self.export_meta_graph(meta_graph_filename)
  File "/home/-/.local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1414, in export_meta_graph
    graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
  File "/home/-/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2257, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/home/-/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2220, in _as_graph_def
    raise ValueError("GraphDef cannot be larger than 2GB.")
ValueError: GraphDef cannot be larger than 2GB.

I believe it comes from this code:

weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden1")[0]
weights = tf.scatter_nd_update(weights, indices, updates)
weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden2")[0]
weights = tf.scatter_nd_update(weights, indices, updates)

I'm not sure why my model is so large (15k steps and 240MB). Any ideas? Thanks!

1 Answer:

Answer 0 (score: 1)

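It's hard to say what is going on without seeing the code, but in general the size of a TensorFlow model does not grow with the number of training steps; it should stay fixed.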

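If the model size does grow with the number of steps, that suggests new nodes are being added to the computation graph at every step. For example, something like: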

import tensorflow as tf

with tf.Session() as sess:
  for i in xrange(1000):
    sess.run(tf.add(1, 2))
    # or perhaps sess.run(tf.scatter_nd_update(...)) in your case

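will create 3000 nodes in the graph (in every iteration, one for the add, one for the constant 1, and one for the constant 2). Instead, you want to define the computation graph once and run it repeatedly, for example: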

import tensorflow as tf

x = tf.add(1, 2)
# or perhaps x = tf.scatter_nd_update(...) in your case
with tf.Session() as sess:
  for i in xrange(1000):
    sess.run(x)

This keeps a fixed graph of 3 nodes for all 1000 (or any number of) iterations. Hope that helps.
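To make the node counting above concrete without needing TensorFlow installed, here is a toy sketch in plain Python. It is NOT TensorFlow: `ToyGraph`, `toy_add`, `build_inside_loop` and `build_once` are hypothetical names that only count node-creation calls, assuming (as in the answer) that `tf.add(1, 2)` creates two constant nodes plus one add node. The `finalize()` method imitates the real TensorFlow 1.x `tf.Graph.finalize()`, which makes a graph read-only so that accidentally creating an op inside the training loop fails immediately instead of silently growing the graph toward the 2GB GraphDef limit.

```python
# Toy model of a TF1-style graph. This is NOT TensorFlow: it only counts
# node-creation calls to show why building ops inside a loop grows the graph.


class ToyGraph(object):
    """Hypothetical stand-in for tf.Graph that records every node created."""

    def __init__(self):
        self.nodes = []
        self.finalized = False

    def add_node(self, kind):
        # Imitates tf.Graph.finalize(): once finalized, adding nodes is an error.
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.nodes.append(kind)
        return len(self.nodes) - 1

    def finalize(self):
        self.finalized = True


def toy_add(graph, a, b):
    """Hypothetical analogue of tf.add(1, 2): two constants plus one add node."""
    graph.add_node("Const")  # node for a
    graph.add_node("Const")  # node for b
    return graph.add_node("Add")


def build_inside_loop(steps):
    # Anti-pattern: a new op (3 nodes) is created on every iteration.
    g = ToyGraph()
    for _ in range(steps):
        toy_add(g, 1, 2)  # real TF equivalent: sess.run(tf.add(1, 2))
    return len(g.nodes)


def build_once(steps):
    # Fix: build the op once, then only "run" it inside the loop.
    g = ToyGraph()
    x = toy_add(g, 1, 2)
    g.finalize()  # any accidental op creation in the loop would now raise
    for _ in range(steps):
        _ = x  # real TF equivalent: sess.run(x), which adds no nodes
    return len(g.nodes)


if __name__ == "__main__":
    print(build_inside_loop(1000))  # 3000
    print(build_once(1000))         # 3
```

In a real TensorFlow 1.x script, the equivalent guard is to call `tf.get_default_graph().finalize()` after building the model and before the training loop; if something like `tf.scatter_nd_update(...)` is still being called inside the loop, it will raise at the first step and point directly at the offending line.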