TensorFlow中的资格跟踪

时间:2017-06-06 03:59:26

标签: tensorflow gradient-descent reinforcement-learning

根据Sutton的书 - 强化学习:简介,网络权重的更新方程式由下式给出:

theta = theta + alpha * delta * e

其中e t 是资格跟踪。 这类似于具有额外e t 的梯度下降更新 此资格跟踪是否可以包含在TensorFlow的tf.train.GradientDescentOptimizer中?

1 个答案:

答案 0 :(得分:2)

这是使用tf.contrib.layers.scale_gradient进行渐变元素乘法的简单示例。在前向传递中,它只是一个身份操作,在后向传递中,它将渐变乘以第二个参数。

import tensorflow as tf

with tf.Graph().as_default():
  some_value = tf.constant([0.,0.,0.])
  scaled = tf.contrib.layers.scale_gradient(some_value, [0.1, 0.2, 0.3])
  (some_value_gradient,) = tf.gradients(tf.reduce_sum(scaled), some_value)
  with tf.Session():
    print(scaled.eval())
    print(some_value_gradient.eval())

打印:

[ 0.  0.  0.]
[ 0.1         0.2         0.30000001]