我有一个复杂网络的设计版本:
import tensorflow as tf
a = tf.ones([1000])
b = tf.ones([1000])
for i in range(int(1e6)):
a = a * b
我的直觉是这应该需要很少的记忆。只是初始数组分配的空间和一串利用节点的命令,并覆盖存储在张量“a”中的存储器。在每一步。但是内存使用量增长很快。
这里发生了什么,如何在计算张量并多次覆盖时减少内存使用量?
编辑:
感谢Yaroslav的建议,解决方案结果是使用了一个while_loop来最小化图表上的节点数量。这种方法效果很好,速度更快,内存更少,并且全部包含在图形中。
import tensorflow as tf
a = tf.ones([1000])
b = tf.ones([1000])
cond = lambda _i, _1, _2: tf.less(_i, int(1e6))
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b]
i = tf.constant(0)
output = tf.while_loop(cond, body, [i, a, b])
with tf.Session() as sess:
result = sess.run(output)
print(result)
答案 0 :(得分:8)
您的a*b
命令会转换为tf.mul(a, b)
,相当于tf.mul(a, b, g=tf.get_default_graph())
。此命令将Mul
节点添加到当前Graph
对象,因此您尝试向当前图形添加100万个Mul
节点。这也是有问题的,因为你不能序列化大于2GB的Graph对象,有些检查可能会在你处理这么大的图形时失败。
我建议MXNet人员阅读Programming Models for Deep Learning。 TensorFlow是"象征性的"用他们的术语编程,你就把它当作势在必行。
要使用Python循环获得所需内容,您可以构造乘法运算一次,并使用feed_dict
重复运行,以提供更新
mul_op = a*b
result = sess.run(a)
for i in range(int(1e6)):
result = sess.run(mul_op, feed_dict={a: result})
为了提高效率,您可以使用tf.Variable
个对象和var.assign
来避免Python< - > TensorFlow数据传输