我正在进行编码器 - 解码器设置。我希望能够运行一次编码器,然后执行多个解码器运行。我提出的解决方案是向解码器提供TF条件节点(使用tf.where),该节点包含编码器的最终隐藏状态(在这种情况下,当我要求时,TF将运行编码器解码器输出),或带有编码器存储结果的占位符(理论上TF不需要运行编码器)。
以下是代码的相关部分:
encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])
由于我没有从这种方法中获得加速,我很确定它不起作用,并且tf.where的两个分支每次都由TF运行,即使它只需要从占位符读取。
有没有办法使用tf.where这样它不会运行编码器?我已经查看了方法的描述,并且我不确定两个分支是否总是被计算,我在这个问题上看到了相互矛盾的信息。
谢谢!
答案 0 :(得分:0)
当您想要推迟执行其中一个分支直到评估谓词时,可以使用tf.cond()
函数。
encoder_state = tf.cond(
tf.greater_equal(branching_points, 0),
lambda: encoder_state,
lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])
答案 1 :(得分:0)
我尝试使用tf.cond创建模型并提供字典,但是tf.cond只接受一个输入,因此,如果您有多个branching_points,将无法使用。
我已经创建了变通方法,但是它非常复杂,我希望看到一个更好的解决方案,特别是如果true_fn和false_fn在计算上很昂贵,那么它仅会提高性能。
如果在未选择分支的情况下不应执行true_fn或false_fn(例如,如果您在这些函数中使用tf.assign),则该解决方案也很有用
首先,我创建布尔张量:
branch_1 = tf.greater_equal(branching_points, 0)
branch_2 = tf.logical_not(branch_1)
然后我使用布尔掩码仅从分支执行True条件
result_1 = tf.boolean_mask(branch_1)
result_2 = tf.boolean_mask(branch_2)
最后,如果需要,您可以形成一个张量。
如果顺序很重要,则可以使用tf.where(tf.equal(branch_1,True))
和tf.where(tf.equal(branch_2,True))
分别获取result_1和result_2的索引。然后,您应用tf.scatter_nd。
如果顺序无关紧要,您只需使用tf.concat