我有一种异常类型的循环网络,该网络在每个时间步都会更新从状态张量收集的少量状态向量。
在我当前的实现中,我使用dynamic_partition和dynamic_stitch更新这些状态,以构造一个与原始大小相同的新状态张量,并交换了更新后的状态。之所以选择第一遍此实现是因为每个状态更新可能取决于一个或多个状态,这些状态可能在过去的任何时间步已更新。此实现简化了这些更新的收集操作。
不幸的是,我很确定这会导致内存需求随着重复深度的增加而增加。在整个部署期间,我只需要更新一次每个状态,因此我希望内存需求在重复深度方面保持不变。
解决此内存需求的最佳方法是什么?我当时在看TensorArray在每个时间步仅存储更新的状态,但是我将不得不以某种方式协调读取和收集步骤,以获取执行更新所需的输入。我认为这将使效率非常低下,例如tf.map_fn。
有没有一种方法可以将内存有效的稀疏更新应用于状态张量?状态张量的第一维可以从一个小批量到下一小批量。
感谢您的帮助!