Tensorflow中的反向传播(通过时间)代码

时间:2016-04-20 13:23:59

标签: python tensorflow backpropagation

在哪里可以找到Tensorflow(python API)中的反向传播(通过时间)代码?或者是否使用其他算法?

例如,当我创建LSTM网时。

2 个答案:

答案 0 :(得分:8)

TensorFlow中的所有反向传播都是通过自动区分网络正向传递中的操作,并添加显式操作来计算网络中每个点的梯度来实现的。可以在tf.gradients()中找到常规实现,但使用的特定版本取决于LSTM的实现方式:

  • 如果将LSTM实现为有限数量时间步的展开循环,则通常的方法是截断反向传播,使用tf.gradients()中的算法在相反方向上构建展开的反向传播循环。 / LI>
  • 如果将LSTM实施为tf.while_loop(),则会使用其他支持来区分control_flow_grad.py中的循环。

答案 1 :(得分:0)

我不确定,但这可能有效:

由于RNN可​​以像前馈网一样进行训练,因此代码非常相似。 这就是你训练前馈网的方法:( X是输入)

train = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)

# Session
sess = tf.Session()
sess.run(tf.initialize_all_variables())

for i in range(epochs):
    sess.run(train, feed_dict={X: [[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]], labels: [[0], [1], [1], [0]]})

时间反向传播的唯一区别是现在每个时代都有一个嵌套的时间循环

这是训练一个简单的rnn的代码:

train = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)

time_series = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
for i in range(number_of_epochs):
    for j in range(len(time_series) - 1):
        curr_X = time_series[j+1]
        curr_prev = time_series[j]
        lbs = curr_prev
        sess.run(train, feed_dict={X: [[curr_X]], prev_val: [[curr_prev]], labels: [[lbs]]})

在此代码中,rnn学习了替代1和0的时间序列。