例如,我想使用一些辅助损耗来提高模型性能。
哪种类型的代码可以在pytorch中实现?
#one
loss1.backward()
loss2.backward()
loss3.backward()
optimizer.step()
#two
loss1.backward()
optimizer.step()
loss2.backward()
optimizer.step()
loss3.backward()
optimizer.step()
#three
loss = loss1+loss2+loss3
loss.backward()
optimizer.step()
谢谢您的回答!
答案 0 :(得分:7)
第一次尝试和第三次尝试完全相同且正确,而第二次尝试则完全错误。
在Pytorch中,原因是低层渐变不会被随后的backward()
调用“覆盖”,而是被累积或求和。这使第一种方法和第三种方法完全相同,但是如果您使用低内存的GPU / RAM,则第一种方法可能更可取,因为具有立即backward() + step()
调用的1024个批处理大小与8个128和8个{ {1}}个呼叫,最后有一个backward()
呼叫。
为了说明这个想法,这是一个简单的示例。我们希望同时使张量step()
最接近x
:
[40,50,60]
现在是第一种方法:(我们使用x = torch.tensor([1.0],requires_grad=True)
loss1 = criterion(40,x)
loss2 = criterion(50,x)
loss3 = criterion(60,x)
来获取张量tensor.grad
的当前梯度)
x
这将输出:loss1.backward()
loss2.backward()
loss3.backward()
print(x.grad)
(编辑:将tensor([-294.])
放在前两个retain_graph=True
中,以获取更复杂的计算图)
第三种方法:
backward
同样的输出是:loss = loss1+loss2+loss3
loss.backward()
print(x.grad)
第二种方法有所不同,因为在调用tensor([-294.])
方法之后我们不调用opt.zero_grad
。这意味着在所有3个step()
调用中,使用了第一个step
调用的梯度。例如,如果3个损失为相同的重量提供了梯度backward
,而不是10(= 5 + 1 + 4),那么您的体重将具有5,1,4
作为梯度。
答案 1 :(得分:5)
-关于第一种方法的注释已删除,请参见其他答案-
您的第二种方法将要求您使用retain_graph=True
向后传播,这会导致大量的计算成本。此外,这是错误的,因为您将使用第一步优化程序来更新网络权重,然后您的下一个backward()
调用将在更新之前计算梯度,这意味着second step()
调用将在您的更新中插入噪音。另一方面,如果您执行了另一个forward()
调用以通过更新的权重进行反向传播,则最终将进行异步优化,因为第一层将使用第一个step()
进行一次更新,然后进行一次更新。随后的每个step()
调用都会获得更多的收益(本质上并没有错,但是效率低下,也许根本不是您想要的)。
答案 2 :(得分:0)
第一次和第三次尝试是正确的,但不相同。
如果初次尝试,它将多次计算/**
* Retrieve all of Dexie collection
*
* We can't just call Collection.toArray() because once the result is large
* enough, we'll get "Maximum IPC message size exceeded" error. This is a
* memory-friendly implementation. Although maybe a bit slow due to a page
* size of one.
*/
function retrieveDexieCollection(collection) {
return new Promise(async (resolve, reject) => {
try {
const result = []
await collection.each(r => {
result.push(r)
})
return resolve(result)
} catch (err) {
return reject(err)
}
})
}
// then later, use our function
const projectIds = [1,2,3]
const records = await retrieveDexieCollection(db
.whateverYourTableIsCalled
.where('projectId')
.anyOf(projectIds))
的梯度流,
但只有三次尝试。
与mappingFunction
梯度计算相同。
答案 3 :(得分:0)
如果您有两种不同的损失函数,请分别完成它们的正向转发,然后最后可以进行(loss1 + loss2).backward()
。效率更高,跳过了很多计算。
您要执行的代码:
loss_sum += loss.item()
以确保您不跟踪所有损失的历史记录。
item()
将破坏图形,从而使它可以从循环的一次迭代释放到下一次迭代。同样,您也可以使用detach()
。