了解tf.cond图

时间:2019-06-16 17:18:09

标签: tensorflow

我试图通过查看用于构建tf.cond的低级操作的图表来了解tf.cond的内部工作原理。

这是有问题的代码

x = tf.constant(1, name='x')
y = tf.constant(2, name='y')
z = tf.constant(3, name='z')

result = tf.cond(tf.less(x, y), lambda: tf.add(x, z), lambda: tf.square(y))

以及生成的graph

我想知道的是以下情况。

  • 导致switch_t和switch_f在那做的单独的开关操作是什么?
  • 为什么更少的输出连接到pred_id和Switch Ops?我希望它为tf.cond中的所有输入输出一个通往所有Switch-es的谓词(布尔量张量)。
  • 什么是pred_id?只是将谓词分为三个分支的身份吗?
  • 我试图了解tf.cond如何在运行时仅评估一个分支。

    我了解,当我们评估tf.cond的结果时,我们正在评估来自合并操作的张量。合并必须接受(我假设)四个张量作为输入(两个来自Add分支,两个来自Square分支),其中三个死了。但是,如果我们将张量沿图向下评估,我们只能知道张量的“失效”,不是吗?

  • 例如,对于Square分支,我知道switch将y张量和谓词作为输入,并在switch语句的相应分支上输出y和dead两个张量。但是该分支已经保证一定是针对错误的分支。那么,开关分支的T和F输出会发生什么呢?他们俩都进入Square Op?然后进入合并?

谢谢

S

0 个答案:

没有答案