批量处理数据以防止Out of Memory异常

时间:2017-02-09 12:45:36

标签: python tensorflow

在我的TensorFlow网络中,我计算

saveChanges

其中:

  • Y = (X * C) * R, X
  • (n, d)C
  • (d, L)X * C
  • (n, L)R
  • (L, d)Y

问题在于(n, d)n ≈ 1e6,因此L ≈ 1e5不适合内存,但是X * C,因此d = 4适合内存。我真的不需要存储Y

X * C可以分批计算,即

Y = (X * C) * R

有没有一种简洁的方法告诉TensorFlow这样做?或者这里的标准程序是什么?

1 个答案:

答案 0 :(得分:0)

以下似乎有效但创建变量会产生不良副作用!

关于如何改进的任何想法?

n, d, L = 40, 2, 3
num_batches = 10
rows_per_batch = n // num_batches

X = tf.constant(np.arange(n*d, dtype=np.float32).reshape(n, d))
C = tf.constant(np.arange(d*L, dtype=np.float32).reshape(d, L))
R = tf.constant(np.arange(L*d, dtype=np.float32).reshape(L, d))

Y = tf.matmul(tf.matmul(X, C), R)

yvar = tf.Variable(tf.constant(0, shape=Y.get_shape(), dtype=tf.float32))

def condition(current_batch):
    return current_batch < num_batches

def body(current_batch):
    current_rows_to_update = tf.range(rows_per_batch * current_batch,
                                      rows_per_batch * (current_batch + 1))
    xbatch = tf.gather(X, current_rows_to_update)
    ybatch = tf.matmul(tf.matmul(xbatch, C), R, name='matmul_da')

    yvar_update = tf.scatter_update(yvar, current_rows_to_update, ybatch)

    with tf.control_dependencies([yvar_update]):
        return current_batch + 1

loop = tf.while_loop(condition, body, loop_vars=[0])
with tf.control_dependencies([loop]):
    ybatched = tf.identity(yvar)