如何计算多个图像的丢失,然后反向传播平均损失并更新网络权重

时间:2017-11-05 09:40:34

标签: backpropagation pytorch loss

我正在做一个批量大小为1的任务,即每个批次只包含1个图像。所以我必须进行手动批处理:当累计损失的数量达到一个数字时,平均损失然后进行反向传播。 我原来的代码是:

ave_loss.backward()

此代码将提供错误消息

  

RuntimeError:尝试第二次向后遍历图形,但缓冲区已被释放。第一次向后调用时指定retain_graph = True。

我尝试过以下方式,

第一种方式(失败)

我阅读了一些关于此错误消息的帖子,但无法完全理解。将ave_loss.backward(retain_graph=True)更改为nan会阻止出现错误消息,但损失不会改善即将变为total_loss = total_loss + loss.data[0]

第二种方式(失败)

我还尝试更改real_batchsize,这也会阻止错误消息。但损失总是一样的。所以一定有问题。

第三种方式(成功)

按照this post中的说明,对于每个图像的丢失,我将损失除以real_batchsize并将其支持。当输入图像的数量达到optimizer.step()时,我使用real_batchsize进行一次参数更新。随着训练过程的进行,损失正在缓慢减少。但是训练速度非常慢,因为我们为每个图像提供反向支持。

我的问题

错误消息在我的案例中意味着什么?另外,为什么第一种方式和第二种方式不起作用?如何正确编写代码,以便我们可以对每个fun SharedPreferences.Editor.putIntArray(key: String, value: IntArray): SharedPreferences.Editor { return putString(key, value.joinToString( separator = ",", transform = { it.toString() })) } fun SharedPreferences.getIntArray(key: String): IntArray { with(getString(key, "")) { with(if(isNotEmpty()) split(',') else return intArrayOf()) { return IntArray(count(), { this[it].toInt() }) } } } 图像进行反向渐变,并更新渐变一次,以便提高训练速度?我知道我的代码几乎是正确的,但我只是不知道如何改变它。

1 个答案:

答案 0 :(得分:3)

您在这里遇到的问题与PyTorch如何在不同的传递上累积渐变有关。 (有关类似问题的其他帖子,请参阅here) 因此,让我们看看当您拥有以下形式的代码时会发生什么:

loss_total = Variable(torch.zeros(1).cuda(), requires_grad=True)
for l in (loss_func(x1,y1), loss_func(x2, y2), loss_func(x3, y3), loss_func(x4, y4)):
    loss_total = loss_total + l
    loss_total.backward()

这里,当loss_total在不同的迭代中具有以下值时,我们会进行向后传递:

total_loss = loss(x1, y1)
total_loss = loss(x1, y1) + loss(x2, y2)
total_loss = loss(x1, y1) + loss(x2, y2) + loss(x3, y3)
total_loss = loss(x1, y1) + loss(x2, y2) + loss(x3, y3) + loss(x4, y4)

因此,当您每次.backward()total_loss时,您实际上会在.backward() 上四次致电loss(x1, y1)(和{{ 1}}三次,等等。

将其与其他帖子中讨论的内容相结合,即为了优化内存使用,PyTorch将在调用loss(x2, y2)时释放附加到变量的图形(从而破坏连接.backward()的渐变x1y1x2等等,您可以看到错误消息的含义 - 您尝试多次向后传递丢失,但基础图已被释放在第一次通过后。 (当然,除非指定y2

至于您尝试过的具体变化: 第一种方式:在这里,你将永远积累(即总结 - 再次,见另一篇文章)渐变,与它们(可能)加起来retain_graph=True。 第二种方式:在这里,您通过执行inf,删除loss包装器,从而删除渐变信息(因为只有变量保持渐变),将loss.data转换为张量。 第三种方式:在这里,你只需要通过每个Variable元组,因为你立即做了一个backprop步骤,完全避免了上述问题。

解决方案:我没有测试过,但是从我收集的内容来看,解决方案应该非常简单:在每个批处理的开头创建一个新的xk, yk对象,然后将所有损失加总到该对象中,然后在最后执行一个最后的backprop步骤。