关于循环变量的tf.gradient

时间:2018-07-14 12:55:09

标签: python tensorflow

我正在尝试计算w.r.t. tf.while_loop中的循环变量,显然,这是不可能的。 玩具示例:

x = tf.Variable(10)
fx = tf.square(x) #possibly very complicated function

def cond(x):
    return tf.not_equal(x, 0)

def body(x):
    return  tf.cond(tf.gradients(fx, x)[0] > 0, lambda: x-1, lambda: x+1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o = sess.run(tf.while_loop(cond, body, [x]))
    print(o)

运行上面的代码将产生结果:

TypeError: unorderable types: NoneType() > int()

原因(据我所知)是循环内的变量x不在全局函数fx中。

如何解决此问题?一个想法是将fx添加到循环变量中,但是如果没有在每个循环中完全重建计算图,我找不到一种维护它的方法。

谢谢!

更新: 进行以下更改可以解决该问题:

更改功能:

def f(x):
    return tf.square(x) #possibly very complicated function

并更改正文:

def body(x):
    return  tf.cond(tf.gradients(f(x), x)[0] > 0, lambda: x-1, lambda: x+1)

但是此解决方案会在每次迭代中重建图,这可能效率不高。

1 个答案:

答案 0 :(得分:0)

一个可行的解决方案是简单地使用带有控制依赖项的assign:

x = tf.Variable(10)
fx = tf.square(x) #possibly very complicated function

def cond(x_i):
    return tf.not_equal(x_i, 0)

def body(x_i):
    update_x = tf.assign(x,x_i)
    with tf.control_dependencies([update_x]):
        return  tf.cond(tf.gradients(fx, x)[0] > 0, lambda: x_i-1, lambda: x_i+1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o = sess.run(tf.while_loop(cond, body, [x]))
    print(o)