Tensorflow:可以阻止执行某个分支吗?

时间:2017-06-15 13:46:32

标签: tensorflow

我正在进行编码器 - 解码器设置。我希望能够运行一次编码器,然后执行多个解码器运行。我提出的解决方案是向解码器提供TF条件节点(使用tf.where),该节点包含编码器的最终隐藏状态(在这种情况下,当我要求时,TF将运行编码器解码器输出),或带有编码器存储结果的占位符(理论上TF不需要运行编码器)。

以下是代码的相关部分:

encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
                         rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

由于我没有从这种方法中获得加速,我很确定它不起作用,并且tf.where的两个分支每次都由TF运行,即使它只需要从占位符读取。

有没有办法使用tf.where这样它不会运行编码器?我已经查看了方法的描述,并且我不确定两个分支是否总是被计算,我在这个问题上看到了相互矛盾的信息。

谢谢!

2 个答案:

答案 0 :(得分:0)

当您想要推迟执行其中一个分支直到评估谓词时,可以使用tf.cond()函数。

encoder_state = tf.cond(
    tf.greater_equal(branching_points, 0),
    lambda: encoder_state,
    lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

答案 1 :(得分:0)

我尝试使用tf.cond创建模型并提供字典,但是tf.cond只接受一个输入,因此,如果您有多个branching_points,将无法使用。
我已经创建了变通方法,但是它非常复杂,我希望看到一个更好的解决方案,特别是如果true_fn和false_fn在计算上很昂贵,那么它仅会提高性能。 如果在未选择分支的情况下不应执行true_fn或false_fn(例如,如果您在这些函数中使用tf.assign),则该解决方案也很有用

首先,我创建布尔张量:

branch_1 = tf.greater_equal(branching_points, 0)
branch_2 = tf.logical_not(branch_1)

然后我使用布尔掩码仅从分支执行True条件

result_1 = tf.boolean_mask(branch_1)
result_2 = tf.boolean_mask(branch_2)

最后,如果需要,您可以形成一个张量。 如果顺序很重要,则可以使用tf.where(tf.equal(branch_1,True))tf.where(tf.equal(branch_2,True))分别获取result_1和result_2的索引。然后,您应用tf.scatter_nd。 如果顺序无关紧要,您只需使用tf.concat