tf.gradients()不能与'tf.assign()'一起使用,但可以使用'='

时间:2018-01-13 17:48:31

标签: python tensorflow graph autograd

在下面的简单代码中,渐变得到正确计算。

import tensorflow as tf

x = tf.constant([1, 2, 3, 4], dtype=tf.float32)
y = tf.Variable(tf.ones_like(x), dtype=tf.float32)

y = 2*x
grad = tf.gradients(y, x)

ini = tf.global_variables_initializer()


with tf.Session() as ses:
    ses.run(ini)
    print(ses.run(grad))

结果如预期的那样是[array([ 2., 2., 2., 2.], dtype=float32)]。 尝试使用tf.assign进行函数计算时遇到问题。以下代码:

import tensorflow as tf

x = tf.constant([1, 2, 3, 4], dtype=tf.float32)
y = tf.Variable(tf.ones_like(x), dtype=tf.float32)

func = tf.assign(y, 2*x)
grad = tf.gradients(y, x)

ini = tf.global_variables_initializer()

with tf.Session() as ses:
    ses.run(ini)
    ses.run(func)
    print(ses.run(grad))

...产生错误:

  

TypeError:Fetch参数None具有无效类型<class 'NoneType'>

为什么会这样? xy节点之间的连接是否通过tf.assign操作以某种方式“丢失”了?

1 个答案:

答案 0 :(得分:1)

在第二个示例中,<script src="https://threejs.org/build/three.min.js"></script> <script src="https://threejs.org/examples/js/controls/OrbitControls.js"></script> <div id="data"> phi: <span id="phi"></span>&deg;<br> theta: <span id="theta"></span>&deg; </div>x之间没有依赖关系。 y依赖于并且恰好修改func的操作。如果您检查相应的y操作,您会看到:

tf.assign

但是op: "Assign" input: "Variable" # this is y input: "mul" # this is 2*x x是独立的,这就是引擎无法采用渐变的原因。