我有一个复杂的网络,有许多重复的RNN步骤。编译需要很长时间(30多分钟,大部分都停留在渐变步骤),我发现this issue可能是相关的,它提到dynamic_rnn是一种更快的编译方式:
查看dynamic_rnn,然后重新格式化我的网络以包含一个while_loop,如下所示:
#input: tensor with 1000 time steps
def body(i, prev_state):
inp = tf.slice(input, i, 1)
new_state = cell(tf.squeeze(int), prev_state) # Includes scope reuse
return [tf.add(i, tf.constant(1)), new_state]
def cond(i):
return some_cond(i)
tf.while_loop(cond, body, [tf.constant(0), initial_state])
但这似乎没有帮助。除了简单地将单元调用放在循环中之外,还能使dynamic_rnn编译得更快吗?