我有一个损失函数,需要多次内部通过:
def my_loss_func(logits, sigma, labels, num_passes):
total_loss = 0
img_batch_size = logits.shape[0]
logits_shape = list(logits.shape)
for fpass in range(num_passes):
noise_array = torch.normal(mean=0.0, std=1.0, size=logits_shape, device=torch.device('cuda:0'))
stochastic_output = logits + sigma * noise_array
del noise_array
exponent_B = torch.log(torch.sum(torch.exp(stochastic_output), dim=-1, keepdim=True))
inner_logits = exponent_B - stochastic_output
soft_inner_logits = labels * inner_logits
total_loss += torch.exp(soft_inner_logits)
del exponent_B, inner_logits, soft_inner_logits
mean_loss = total_loss / num_passes
actual_loss = torch.mean(torch.log(mean_loss))
return actual_loss
logit和sigma都是网络输出,因此具有关联的梯度。瓶颈(预期地)与行total_loss += torch.exp(soft_inner_logits)
有关,因为afaik会为随后的遍历添加新的计算图。我已经读过,在循环中调用loss.backward()可以在类似情况下提供帮助,但是不幸的是,我需要在循环和背景之后基于此记录输出,因此该解决方案在这里似乎不可行。 >
更具体地说,当num_passes超过20时,我遇到了内存问题,还有其他方法可以完全优化内存分配以允许更多的通过吗?我根本不关心可读性/丑陋的解决方案,任何建议都会有很大帮助。