TensorFlow while_loop占用了所有内存

时间:2019-06-13 19:49:28

标签: python tensorflow memory optimization while-loop

我想做一个while循环。不,我是说我需要一段时间的循环。

  1. 循环根据渐变更新数据
  2. 每个更新还取决于以前的更新。

我已尽力通过近似和批处理对循环进行部分矢量化。我无法避免使用循环。

不好的是我的while_loop耗尽了所有GPU内存。

我试图将分配op的定义移出内存...但这导致循环开始前只有一个分配。我想在循环的每次迭代中分配一个任务。

如何解决疯狂的内存消耗问题?

非常感谢。

测试代码:

import tensorflow as tf
import numpy as np
import time
from functools import partial


def precursor_condition(end_index, loop_index):
    return tf.less(loop_index, end_index)


def make_condition(end_index):
    return partial(precursor_condition, end_index)


def get_gradient(x):
    y = x * x
    return tf.gradients(y, x, stop_gradients=x)[0]


def precursor_loop_body(data, one, loop_index):
    data_prev = data[loop_index - 1]
    data_now = data[loop_index]
    assignment = tf.assign(data_now, get_gradient(data_now) + data_prev)
    with tf.control_dependencies([assignment]):
        return tf.add(loop_index, one)


def make_loop_body(data, one):
    return partial(precursor_loop_body, data, one)


def make_while_loop(data, dtype=np.int32):
    i = tf.get_variable('i', dtype=np.int32, initializer=1)
    one = tf.constant(1, dtype=dtype)
    end_ix = tf.constant(data.get_shape()[0], dtype=dtype)
    condition = make_condition(end_ix)
    body = make_loop_body(data, one)
    loop_vars = [i]
    return tf.while_loop(condition, body, loop_vars,
                         back_prop=False, parallel_iterations=1)


def main():
    print("TensorFlow version: {}".format(tf.__version__))
    count = int(1e9)
    initializer = tf.range(count, dtype=np.float32)
    data = tf.get_variable('data', dtype=np.float32,
                           trainable=False, initializer=initializer)

    print("initial data type", type(data))

    while_loop = make_while_loop(data)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.graph.finalize()
        sess.run(init)
        st = time.time()
        print("result:", sess.run([while_loop, data]))
        en = time.time()
        print("second per iteration", (en - st) / count)


main()

顺便说一句,tf.scan的内存消耗(I tried更加可怕。)

0 个答案:

没有答案