如何在张量流中实现条件分支?

时间:2018-09-27 03:04:01

标签: python tensorflow

请原谅我一个简单的问题。

我想实现条件分支,如:

if need_backprop != 0:
    cross_entropy = ......

if need_backprop == 0:
    tf.stop_gradient(cross_entropy)

我发现“ if”语句无效。所以我想知道是否有任何方法可以实现条件分支。

谢谢!

1 个答案:

答案 0 :(得分:2)

我认为 tf.cond: https://www.tensorflow.org/api_docs/python/tf/cond是您想要的。

例如,您的代码可能是这样的,

def case1():
       return cross_entropy = ......

def case2():
       return tf.stop_gradient(cross_entropy)

result = tf.cond(need_backprop != 0, lambda: case1(), lambda: case2())

希望这会有所帮助。