TF op在switch上下文中,为什么在外部输入上需要`switch`

时间:2018-08-14 09:39:53

标签: tensorflow

示例代码:

import tensorflow as tf

x = tf.constant(1.0)

def cond_branch(cond):
  y = tf.multiply(x, 2.0 if cond else 3.0)
  print("cond branch %r" % cond)
  print(y.op)
  print(y.op.inputs[0].op)
  return y

y = tf.cond(
  tf.placeholder(tf.bool),
  lambda: cond_branch(True),
  lambda: cond_branch(False))

在切换上下文中,x是一个外部张量。 “外部”表示它来自外部上下文。

您将获得类似以下的输出(仅适用于True分支,并且已被剥离): tf.multiply操作:

name: "cond/Mul"
op: "Mul"
input: "cond/Mul/Switch:1"
input: "cond/Mul/y"
...

tf.multiply的第一个输入,它是外部x

name: "cond/Mul/Switch"
op: "Switch"
input: "Const"
input: "cond/pred_id"
...

当您查看TF代码时,在x中基本上用switch(x, pred)[branch]替换了对CondContext._AddOpInternal的直接引用。

为什么这里需要switch?确保操作tf.multiply仅在正确的分支中执行?为什么添加控制输入(op._add_control_input(self._pivot.op))还不够?基本上,CondContext._AddOpInternal中的代码为何如此,为什么if not op.inputs中的代码在所有情况下均不起作用?

0 个答案:

没有答案