由于 tf.scan 即使设置了 back_prop=False 仍然会出现 OOM(内存不足),所以我改用 tf.while_loop,并手工推导那个很长的导数。但是我的 while_loop 非常慢。有没有办法加快 tf.while_loop 的速度?
结果:
result: 1000000
second per iteration 5.064262390136719e-06
测试代码:
import tensorflow as tf
import numpy as np
import time
def make_while_loop(counts, dtype=np.int32):
    """Build a graph-mode ``tf.while_loop`` that counts from 0 up to ``counts``.

    Each iteration does a single ``tf.add``. Running the returned tensor
    therefore mostly measures the per-iteration overhead of
    ``tf.while_loop`` itself.

    The counter is created with ``tf.get_variable('i', ...)`` in the current
    graph. Calling this twice in the same graph will therefore fail unless
    variable reuse is enabled.

    Args:
        counts: Loop bound. The loop runs while the counter is less than
            ``counts``.
        dtype: NumPy dtype shared by the counter, the increment and the bound.
            Defaults to ``np.int32``.

    Returns:
        The final counter value as a single tensor. ``tf.while_loop`` returns
        a bare tensor, not a list, when given exactly one loop variable. The
        value equals ``counts`` for ``counts >= 0``, and 0 otherwise.
    """
    # Use `dtype` rather than a hard-coded int32. tf.less/tf.add below need
    # the counter, `one` and `loop_ends` to have matching dtypes.
    i = tf.get_variable('i', dtype=dtype, initializer=dtype(0))
    one = tf.constant(1, dtype=dtype)
    loop_ends = tf.constant(counts, dtype=dtype)

    def condition(i):
        return tf.less(i, loop_ends)

    def increment(i):
        return tf.add(i, one)

    # back_prop=False: no gradient bookkeeping is kept for the loop.
    return tf.while_loop(condition, increment, [i], back_prop=False)
def main():
    """Benchmark a 1e6-iteration counting ``tf.while_loop`` and print results.

    Prints the final counter value and the average wall-clock seconds per
    loop iteration. The timed region is one ``sess.run`` call, so it
    includes session-run overhead in addition to the loop itself.
    """
    count = int(1e6)
    # Build into a fresh graph so repeated calls don't collide on the
    # variable name 'i' that make_while_loop creates via tf.get_variable.
    with tf.Graph().as_default():
        i = make_while_loop(count)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            # Make the graph read-only. Any accidental op creation during
            # the benchmark then raises instead of silently growing the graph.
            sess.graph.finalize()
            sess.run(init)
            # perf_counter is the monotonic, high-resolution clock meant
            # for measuring short intervals.
            start = time.perf_counter()
            # A single-loop-variable tf.while_loop yields a bare tensor, so
            # the fetched value is a scalar. Indexing it with [1] would raise.
            result = sess.run(i)
            elapsed = time.perf_counter() - start
    print("result:", result)
    print("second per iteration", elapsed / count)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()