评估期间内存不足,但培训工作正常

时间:2018-01-15 21:38:04

标签: pytorch

我最近将PyTorch从0.2升级到0.3。令人惊讶的是,我的旧程序在评估期间(在eval()模式下)会丢失内存错误但是训练工作正常。我使用相同的批量大小进行培训和评估。我完全不知道发生了什么事?有没有人面临类似的问题?有没有可能的解决方案?

我尝试在变量上使用volatile=True参数,但没有帮助。请注意,我没有做任何特殊的使用cuDNN。我使用的是默认设置。

def validate(self, dev_corpus):
    # Turn on evaluation mode which disables dropout.
    self.model.eval()

    dev_batches = helper.batchify(dev_corpus.data, self.config.batch_size)
    print('number of dev batches = ', len(dev_batches))

    dev_loss = 0
    num_batches = len(dev_batches)
    for batch_no in range(1, num_batches + 1):
        session_queries, session_query_length, rel_docs, rel_docs_length, doc_labels = helper.session_to_tensor(
            dev_batches[batch_no - 1], self.dictionary)
        if self.config.cuda:
            session_queries = session_queries.cuda()
            session_query_length = session_query_length.cuda()
            rel_docs = rel_docs.cuda()
            rel_docs_length = rel_docs_length.cuda()
            doc_labels = doc_labels.cuda()

        loss = self.model(session_queries, session_query_length, rel_docs, rel_docs_length, doc_labels)
        if loss.size(0) > 1:
            loss = loss.mean()
        dev_loss += loss.data[0]

    return dev_loss / num_batches

我正在使用上述功能进行评估。这里,session_queries,session_query_length,.... rest变量是通过启用volatile=True创建的。

请帮助!!

1 个答案:

答案 0 :(得分:2)

我认为它在验证期间会失败,因为volatile标志现在已被弃用,并且无效。从0.4.0开始,为避免在验证期间为所有变量计算梯度,应使用上下文管理器。代码示例:

with torch.no_grad():
    # Your validation code

因此,不会存储操作历史记录和渐变。这样可以节省内存。

有关更多详细信息,请参见0.4.0 migration guide

此外,您可以在完成评估后删除对这些变量的引用,如下所示:

someVar = Variable(someVar)
del someVar