Retain_graph = True导致的自定义损失会使代码变慢

时间:2019-11-22 16:18:27

标签: pytorch loss-function

我有一个在Pytorch中实现的带有两个分支的模型。两个分支都共享嵌入层,但是它们分开了,并且根据不同的标签计算每个损失。最后,我将两个损失加起来,而后加累积损失。但是,如果没有retain_graph=True,则会引发错误,要求将retain_graph设置为True。但是,在保留图形时,后退步幅是如此之慢,以至于几乎无法进行训练。有什么解决办法吗?

0 个答案:

没有答案