我想为两个训练编写一个解码器(应该将渐变传递给编码器)和波束搜索模式(从python单步执行,遗憾的是,所以没有链接到编码器直接)。
理想情况下,这样的事情会起作用:
decoder(beamSearchFlag_boolPlaceholder, initalState_fromEncoder, initialState_placeholder, input):
initialState = tf.cond(beamSearchFlag_boolPlaceholder,
lambda: initialState_placeholder,
lambda: initalState_fromEncoder)
... = cell(input, initialState)
但是使用cond()TF仍然需要解析两个分支的依赖关系。当beamSearchFlag == False时,_fromEncoder分支被执行,即使没有效果,并且这是不必要图形的重要部分。有办法解决这个问题吗?