在推理期间,当我们通过网络传播时,我们不需要保留前一层的激活。但是,由于我们没有明确告诉程序丢弃它们,因此它不区分训练和推理过程。有没有办法 - 也许是一个简单的标志,类,方法 - 在Tensorflow中进行这种内存管理?只需使用tf.stop_gradient
工作吗?
答案 0 :(得分:3)
最简单的方法是使用freeze脚本“freeze_graph.py
”(tensorflow的术语)模型。
此脚本基本上删除了所有不必要的操作,并且还用常量替换所有变量,然后将生成的图形导出回磁盘。
为此,您需要在图表中指定您在推理期间使用的输出。无法到达输出的节点(可能是摘要,损失,渐变等)会自动丢弃。
一旦消除了向后传递,tensorflow可以优化其内存使用量,特别是自动free or reuse memory taken by unused nodes。