我正在尝试计算w.r.t. tf.while_loop中的循环变量,显然,这是不可能的。 玩具示例:
x = tf.Variable(10)
fx = tf.square(x) #possibly very complicated function
def cond(x):
return tf.not_equal(x, 0)
def body(x):
return tf.cond(tf.gradients(fx, x)[0] > 0, lambda: x-1, lambda: x+1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
o = sess.run(tf.while_loop(cond, body, [x]))
print(o)
运行上面的代码将产生结果:
TypeError: unorderable types: NoneType() > int()
原因(据我所知)是循环内的变量x不在全局函数fx中。
如何解决此问题?一个想法是将fx添加到循环变量中,但是如果没有在每个循环中完全重建计算图,我找不到一种维护它的方法。
谢谢!
更新: 进行以下更改可以解决该问题:
更改功能:
def f(x):
return tf.square(x) #possibly very complicated function
并更改正文:
def body(x):
return tf.cond(tf.gradients(f(x), x)[0] > 0, lambda: x-1, lambda: x+1)
但是此解决方案会在每次迭代中重建图,这可能效率不高。
答案 0 :(得分:0)
一个可行的解决方案是简单地使用带有控制依赖项的assign:
x = tf.Variable(10)
fx = tf.square(x) #possibly very complicated function
def cond(x_i):
return tf.not_equal(x_i, 0)
def body(x_i):
update_x = tf.assign(x,x_i)
with tf.control_dependencies([update_x]):
return tf.cond(tf.gradients(fx, x)[0] > 0, lambda: x_i-1, lambda: x_i+1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
o = sess.run(tf.while_loop(cond, body, [x]))
print(o)