我正在实施Alex Graves的特定RNN单元包装器(Adaptive Computation Time)。作为我实现的一部分,我需要为用户提供几个张量,总结在每个时间步骤发生的一些计算(例如,在损失函数中使用的“思考成本”)。当然,在构建计算图时需要创建和返回这些张量。
如果static_rnn
用于创建RNN,这不是问题,因为我的包装器的call
方法被调用的次数与时间步数一样多,因此我可以从每个方法编译所有必要的信息时间步入单张量。但是,如果RNN是由dynamic_rnn
创建的,它使用tf.while_loop
,call
方法只调用一次进行初始化,因此我无法捕获单个时间步长结果
因此,我的问题是:有没有方法可以在调用方法的外部tf.while_loop
的所有迭代中携带张量,并对它们进行汇总(例如使用tf.stack
或{{ 1}})在图形构建期间?非常感谢!