如何在PyTorch的计算图中将张量附加到特定点?

时间:2019-03-15 22:25:30

标签: python pytorch autograd computation-graph

正如问题所述,我需要在Pytorch的计算图中的特定点上附加张量。

我要这样做的是: 从所有迷你批次获取输出时,将它们累加到列表中,当一个时期结束时,计算平均值。然后,我需要根据均值计算损耗,因此反向传播必须考虑所有这些操作。

当训练数据不多时(无需分离和存储),我能够做到这一点。但是,当它变大时,这是不可能的。如果我没有每次都分离输出张量,那么我的GPU内存就用光了;如果我分离,我将从计算图中失去输出张量的轨迹。不管我有多少个GPU,这看起来都是不可能的,因为即使分配了4个以上的GPU,PyTorch也会在将它们保存到列表中之前不分离就只使用前4个来存储输出张量。

我们非常感谢您的帮助。

谢谢。

0 个答案:

没有答案