所以我的pytorch模型如下:
gathered_output = torch.nn.parallel.data_parallel(model, model_input, range(ngpu))
loss = F.loss(gathered_output, ground_truth)
loss.backward()
optimizer.step()
我正在做一些大的图像分割任务。 GPU 0使用的内存几乎增加了20%。有什么办法可以减轻这种情况?返回的数据不仅仅是cuda中的[8,128,300,400]浮点张量。我有4个GPU,每个GPU可以处理2个批处理。但是我想梯度也都被收集了吗?
如果我计算每个GPU上的损耗,并且只收回损耗,我会节省更多内存吗?请帮助