在Pytorch评估过程中耗尽内存

时间:2017-11-02 23:43:19

标签: pytorch

我在pytorch训练模型。每10个时代,我正在评估整列火车和测试数据集的列车和测试误差。由于某种原因,评估功能导致我的GPU内存不足。这很奇怪,因为我有相同的批量大小用于培训和评估。我相信这是由于net.forward()方法被重复调用并将所有隐藏值存储在内存中但是我不确定如何解决这个问题?

def evaluate(self, data):
    correct = 0
    total = 0
    loader = self.train_loader if data == "train" else self.test_loader
    for step, (story, question, answer) in enumerate(loader):
        story = Variable(story)
        question = Variable(question)
        answer = Variable(answer)
        _, answer = torch.max(answer, 1)

        if self.config.cuda:
            story = story.cuda()
            question = question.cuda()
            answer = answer.cuda()

        pred_prob = self.mem_n2n(story, question)[0]
        _, output_max_index = torch.max(pred_prob, 1)
        toadd = (answer == output_max_index).float().sum().data[0]
        correct = correct + toadd
        total = total + captions.size(0)

    acc = correct / total
    return acc

2 个答案:

答案 0 :(得分:12)

我认为在验证过程中失败了,因为您没有使用optimizer.zero_grad()。 zero_grad执行detach,使张量成为叶子。它通常用于训练部分的每个时期。

已删除在PyTorch 0.4.0变量中使用volatile标志。 参考 - migration_guide_to_0.4.0

从0.4.0开始,为避免在验证期间计算梯度,请使用torch.no_grad()

迁移指南中的代码示例。

# evaluate
with torch.no_grad():                   # operations inside don't track history
  for input, target in test_loader:
      ...

对于0.3.X,使用volatile应该可以工作。

答案 1 :(得分:10)

我建议对评估过程中使用的所有变量使用 volatile 标记设置为 True

    story = Variable(story, volatile=True)
    question = Variable(question, volatile=True)
    answer = Variable(answer, volatile=True)

因此,不存储渐变和操作历史记录,您将节省大量内存。 此外,您可以在批处理结束时删除对这些变量的引用:

del story, question, answer, pred_prob

不要忘记将模型设置为评估模式(完成评估后返回列车模式)。例如,像这样

model.eval()
相关问题