示例代码:
import tensorflow as tf
x = tf.constant(1.0)
def cond_branch(cond):
y = tf.multiply(x, 2.0 if cond else 3.0)
print("cond branch %r" % cond)
print(y.op)
print(y.op.inputs[0].op)
return y
y = tf.cond(
tf.placeholder(tf.bool),
lambda: cond_branch(True),
lambda: cond_branch(False))
在切换上下文中,x
是一个外部张量。 “外部”表示它来自外部上下文。
您将获得类似以下的输出(仅适用于True
分支,并且已被剥离):
tf.multiply
操作:
name: "cond/Mul"
op: "Mul"
input: "cond/Mul/Switch:1"
input: "cond/Mul/y"
...
tf.multiply
的第一个输入,它是外部x
:
name: "cond/Mul/Switch"
op: "Switch"
input: "Const"
input: "cond/pred_id"
...
当您查看TF代码时,在x
中基本上用switch(x, pred)[branch]
替换了对CondContext._AddOpInternal
的直接引用。
为什么这里需要switch
?确保操作tf.multiply
仅在正确的分支中执行?为什么添加控制输入(op._add_control_input(self._pivot.op)
)还不够?基本上,CondContext._AddOpInternal
中的代码为何如此,为什么if not op.inputs
中的代码在所有情况下均不起作用?