我想做一个while循环。不,我是说我需要一段时间的循环。
我已尽力通过近似和批处理对循环进行部分矢量化。我无法避免使用循环。
不好的是我的while_loop
耗尽了所有GPU内存。
我试图将分配op的定义移出内存...但这导致循环开始前只有一个分配。我想在循环的每次迭代中分配一个任务。
如何解决疯狂的内存消耗问题?
非常感谢。
测试代码:
import tensorflow as tf
import numpy as np
import time
from functools import partial
def precursor_condition(end_index, loop_index):
return tf.less(loop_index, end_index)
def make_condition(end_index):
return partial(precursor_condition, end_index)
def get_gradient(x):
y = x * x
return tf.gradients(y, x, stop_gradients=x)[0]
def precursor_loop_body(data, one, loop_index):
data_prev = data[loop_index - 1]
data_now = data[loop_index]
assignment = tf.assign(data_now, get_gradient(data_now) + data_prev)
with tf.control_dependencies([assignment]):
return tf.add(loop_index, one)
def make_loop_body(data, one):
return partial(precursor_loop_body, data, one)
def make_while_loop(data, dtype=np.int32):
i = tf.get_variable('i', dtype=np.int32, initializer=1)
one = tf.constant(1, dtype=dtype)
end_ix = tf.constant(data.get_shape()[0], dtype=dtype)
condition = make_condition(end_ix)
body = make_loop_body(data, one)
loop_vars = [i]
return tf.while_loop(condition, body, loop_vars,
back_prop=False, parallel_iterations=1)
def main():
print("TensorFlow version: {}".format(tf.__version__))
count = int(1e9)
initializer = tf.range(count, dtype=np.float32)
data = tf.get_variable('data', dtype=np.float32,
trainable=False, initializer=initializer)
print("initial data type", type(data))
while_loop = make_while_loop(data)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.graph.finalize()
sess.run(init)
st = time.time()
print("result:", sess.run([while_loop, data]))
en = time.time()
print("second per iteration", (en - st) / count)
main()
顺便说一句,tf.scan
的内存消耗(I tried更加可怕。)