请原谅我一个简单的问题。
我想实现条件分支,如:
if need_backprop != 0:
cross_entropy = ......
if need_backprop == 0:
tf.stop_gradient(cross_entropy)
我发现“ if”语句无效。所以我想知道是否有任何方法可以实现条件分支。
谢谢!
答案 0 :(得分:2)
我认为 tf.cond: https://www.tensorflow.org/api_docs/python/tf/cond是您想要的。
例如,您的代码可能是这样的,
def case1():
return cross_entropy = ......
def case2():
return tf.stop_gradient(cross_entropy)
result = tf.cond(need_backprop != 0, lambda: case1(), lambda: case2())
希望这会有所帮助。