我是Pytorch DstributedDataParallel()的新手,但是我发现大多数教程在训练过程中都保存了本地等级0 模型。这意味着,如果我得到3台机器,每台机器上都装有4个GPU,最后我将得到3台从每台机器上保存的模型。
例如在第252行的pytorch ImageNet教程中:
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
and args.rank % ngpus_per_node == 0):
save_checkpoint({...})
如果rank % ngpus_per_node == 0
,他们将保存模型。
据我所知,DistributedDataParallel()将自动减少后端的损失,而无需做任何进一步的工作,每个进程都可以基于此自动同步损失。 每个流程上的所有模型在流程结束时只会稍有不同。这意味着我们只需要保存一个模型就足够了。
那么为什么我们不只是将模型保存在rank == 0
上,而是保存在rank % ngpus_per_node == 0
上呢?
如果我有多个模型,应该使用哪个模型?
如果这是在分布式学习中保存模型的正确方法,那么我应该合并它们,使用其中一个模型还是基于所有三个模型推断结果?
如果我错了,请告诉我。
答案 0 :(得分:3)
如果我在任何地方错了,请纠正我
您所指的更改是通过this commit在2018
中引入的,并描述为:
在多处理模式下,只有一个进程将写入检查点
以前,这些文件被保存为,没有任何if
,因此每个GPU上的每个节点都会保存一个确实浪费的模型,并且很可能会在每个节点上多次覆盖保存的模型。
现在,我们正在谈论分布式多处理(可能有许多工作人员,每个工作人员可能有多个GPU)。
每个进程的 args.rank
在this line脚本中被修改:
args.rank = args.rank * ngpus_per_node + gpu
具有以下注释:
对于多处理分布式培训,等级需要为 所有流程中的全球排名
因此args.rank
是所有节点之间所有GPU中的唯一ID (或看起来)。
如果是这样,并且每个节点都有ngpus_per_node
(在此训练代码中,假定每个节点都具有我收集到的GPU数量相同的GPU),则仅为一个(最后一个)GPU保存模型在每个节点上。在您使用3
机器和4
GPU的示例中,您将获得3
保存的模型(希望我能正确理解该代码,因为它非常复杂)。
如果您使用rank==0
,则仅会保存每个世界(其中世界定义为n_gpus * n_nodes
)的一个模型。
那我们为什么不只将模型保存在等级== 0,而是等级% ngpus_per_node == 0吗?
我将从您的假设开始,即:
据我所知,DistributedDataParallel()将自动 都可以减少后端的损失,而无需采取任何其他措施 作业,每个流程都可以基于此自动同步损失。
准确地说,它与损失无关,而是gradient
累积,并根据文档对权重进行了校正(强调我的观点):
此容器通过以下方式并行处理给定模块的应用程序: 通过在批次维度中分块,在指定设备上拆分输入。 模块在每台计算机上复制,并且 每个设备,并且每个这样的副本处理一部分输入。 在向后传递过程中,平均每个节点的梯度。
因此,当创建具有某些权重的模型时,它将在所有设备(每个节点的每个GPU)上复制。现在,每个GPU都获得一部分输入(例如,对于等于1024
,4
个节点且每个节点具有4
个GPU的总批大小,每个GPU将获得64
个元素),通过.backward()
张量方法计算前向通过,损耗并执行反向传播。现在,所有梯度均通过全聚集进行平均,在root
机器上优化了参数,并且参数已分配到所有节点,因此模块状态在所有机器上始终相同。
注意:我不确定这种平均是如何发生的(并且我没有看到它在文档中明确说明),尽管我认为这些平均首先是在GPU上平均,然后是在所有GPU上平均我认为这将是最高效的节点。
现在,为什么要在这种情况下为每个node
保存模型?原则上,您只能保存一个(因为所有模块都完全相同),但是它有一些缺点:
如果我有多个模型,应该使用哪个模型?
这没关系,因为通过优化器将相同的校正应用于具有相同初始权重的模型时,所有这些都将完全相同。
您可以使用类似的方法来加载保存的.pth
模型:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parallel_model = torch.nn.DataParallel(MyModelGoesHere())
parallel_model.load_state_dict(
torch.load("my_saved_model_state_dict.pth", map_location=str(device))
)
# DataParallel has model as an attribute
usable_model = parallel_model.model