计算tf.while_loop的每个时间步的渐变

时间:2018-03-29 11:34:58

标签: python tensorflow while-loop backpropagation

鉴于TensorFlow tf.while_loop,如何针对每个时间步长计算x_out相对于网络所有权重的梯度?

network_input = tf.placeholder(tf.float32, [None])
steps = tf.constant(0.0)

weight_0 = tf.Variable(1.0)
layer_1 = network_input * weight_0

def condition(steps, x):
    return steps <= 5

def loop(steps, x_in):
    weight_1 = tf.Variable(1.0)
    x_out = x_in * weight_1
    steps += 1
    return [steps, x_out]

_, x_final = tf.while_loop(
    condition,
    loop,
    [steps, layer_1]
)

一些笔记

  1. 在我的网络中,条件是动态的。不同的运行将以不同的次数运行while循环。
  2. 使用tf.gradients(x, tf.trainable_variables())呼叫AttributeError: 'WhileContext' object has no attribute 'pred'崩溃。似乎在循环中使用tf.gradients的唯一可能性是计算相对于weight_1的渐变和x_in /时间步的当前值,而不反向传播。
  3. 在每个时间步中,网络将输出概率分布而不是动作。然后,政策梯度实施需要渐变。

1 个答案:

答案 0 :(得分:4)

根据thisthis,您无法在Tensorflow中的tf.gradients内拨打tf.while_loop,我在尝试时遇到了困难的问题将共轭梯度下降完全创建到Tensorflow图中。

但是,如果我正确理解了您的模型,您可以创建自己的RNNCell版本并将其包装在tf.dynamic_rnn中,但是实际的单元格  实现将有点复杂,因为您需要在运行时动态评估条件。

对于初学者,您可以查看Tensorflow的dynamic_rnn代码here

或者,动态图表从来就不是Tensorflow强大的套件,因此请考虑使用其他框架,例如PyTorch,或者您可以试用eager_execution并查看是否有帮助。