如何在PyTorch中优化循环损耗函数的内存使用?

时间:2020-10-21 14:26:56

标签: python pytorch

我有一个损失函数,需要多次内部通过:

def my_loss_func(logits, sigma, labels, num_passes):
    total_loss = 0
    img_batch_size = logits.shape[0]
    logits_shape = list(logits.shape)
    for fpass in range(num_passes):
        noise_array = torch.normal(mean=0.0, std=1.0, size=logits_shape, device=torch.device('cuda:0'))
        stochastic_output = logits + sigma * noise_array
        del noise_array
        exponent_B = torch.log(torch.sum(torch.exp(stochastic_output), dim=-1, keepdim=True))
        inner_logits = exponent_B - stochastic_output
        soft_inner_logits = labels * inner_logits
        total_loss += torch.exp(soft_inner_logits)
        del exponent_B, inner_logits, soft_inner_logits
    mean_loss = total_loss / num_passes
    actual_loss = torch.mean(torch.log(mean_loss))
    return actual_loss

logit和sigma都是网络输出,因此具有关联的梯度。瓶颈(预期地)与行total_loss += torch.exp(soft_inner_logits)有关,因为afaik会为随后的遍历添加新的计算图。我已经读过,在循环中调用loss.backward()可以在类似情况下提供帮助,但是不幸的是,我需要在循环和背景之后基于此记录输出,因此该解决方案在这里似乎不可行。 >

更具体地说,当num_passes超过20时,我遇到了内存问题,还有其他方法可以完全优化内存分配以允许更多的通过吗?我根本不关心可读性/丑陋的解决方案,任何建议都会有很大帮助。

0 个答案:

没有答案