为什么这个tensorflow循环需要这么多内存?

时间:2016-08-03 18:28:08

标签: memory-management tensorflow

我有一个复杂网络的设计版本:

import tensorflow as tf

a = tf.ones([1000])
b = tf.ones([1000])

for i in range(int(1e6)):
    a = a * b

我的直觉是这应该需要很少的记忆。只是初始数组分配的空间和一串利用节点的命令,并覆盖存储在张量“a”中的存储器。在每一步。但是内存使用量增长很快。

这里发生了什么,如何在计算张量并多次覆盖时减少内存使用量?

编辑:

感谢Yaroslav的建议,解决方案结果是使用了一个while_loop来最小化图表上的节点数量。这种方法效果很好,速度更快,内存更少,并且全部包含在图形中。

import tensorflow as tf

a = tf.ones([1000])
b = tf.ones([1000])

cond = lambda _i, _1, _2: tf.less(_i, int(1e6))
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b]

i = tf.constant(0)
output = tf.while_loop(cond, body, [i, a, b])

with tf.Session() as sess:
    result = sess.run(output)
    print(result)

1 个答案:

答案 0 :(得分:8)

您的a*b命令会转换为tf.mul(a, b),相当于tf.mul(a, b, g=tf.get_default_graph())。此命令将Mul节点添加到当前Graph对象,因此您尝试向当前图形添加100万个Mul节点。这也是有问题的,因为你不能序列化大于2GB的Graph对象,有些检查可能会在你处理这么大的图形时失败。

我建议MXNet人员阅读Programming Models for Deep Learning。 TensorFlow是"象征性的"用他们的术语编程,你就把它当作势在必行。

要使用Python循环获得所需内容,您可以构造乘法运算一次,并使用feed_dict重复运行,以提供更新

mul_op = a*b
result = sess.run(a)
for i in range(int(1e6)):
  result = sess.run(mul_op, feed_dict={a: result})

为了提高效率,您可以使用tf.Variable个对象和var.assign来避免Python< - > TensorFlow数据传输