我试图通过查看用于构建tf.cond的低级操作的图表来了解tf.cond的内部工作原理。
这是有问题的代码
x = tf.constant(1, name='x')
y = tf.constant(2, name='y')
z = tf.constant(3, name='z')
result = tf.cond(tf.less(x, y), lambda: tf.add(x, z), lambda: tf.square(y))
以及生成的graph。
我想知道的是以下情况。
我试图了解tf.cond如何在运行时仅评估一个分支。
我了解,当我们评估tf.cond的结果时,我们正在评估来自合并操作的张量。合并必须接受(我假设)四个张量作为输入(两个来自Add分支,两个来自Square分支),其中三个死了。但是,如果我们将张量沿图向下评估,我们只能知道张量的“失效”,不是吗?
例如,对于Square分支,我知道switch将y张量和谓词作为输入,并在switch语句的相应分支上输出y和dead两个张量。但是该分支已经保证一定是针对错误的分支。那么,开关分支的T和F输出会发生什么呢?他们俩都进入Square Op?然后进入合并?
谢谢
S