如何通过外部tf.while_loop(dynamic_rnn)的迭代来携带张量?

时间:2017-06-30 19:12:50

标签: tensorflow

我正在实施Alex Graves的特定RNN单元包装器(Adaptive Computation Time)。作为我实现的一部分,我需要为用户提供几个张量,总结在每个时间步骤发生的一些计算(例如,在损失函数中使用的“思考成本”)。当然,在构建计算图时需要创建和返回这些张量。

如果static_rnn用于创建RNN,这不是问题,因为我的包装器的call方法被调用的次数与时间步数一样多,因此我可以从每个方法编译所有必要的信息时间步入单张量。但是,如果RNN是由dynamic_rnn创建的,它使用tf.while_loopcall方法只调用一次进行初始化,因此我无法捕获单个时间步长结果

因此,我的问题是:有没有方法可以在调用方法的外部tf.while_loop的所有迭代中携带张量,并对它们进行汇总(例如使用tf.stack或{{ 1}})在图形构建期间?非常感谢!

0 个答案:

没有答案