I trained a model with DistributedDataParallel and saved a .pth checkpoint:
if args.gpu is not None:
    print('Gpu setting...', args.gpu)
    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.gpu],
        # output_device=[args.gpu],
        find_unused_parameters=True)
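The save call itself is not shown above; assuming the checkpoint was written from the wrapped model, roughly like the sketch below (the file name is a placeholder, not from the original code), every key in model.state_dict() carries the 'module.' prefix that DistributedDataParallel adds.

# Hypothetical save step, not taken from the original code:
# state_dict() of the DDP wrapper yields keys such as 'module.backbone.conv1.weight'.
torch.save({'state_dict': model.state_dict()}, 'checkpoint.pth')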
Then I tried to evaluate the model:
self.model = EfficientDet(num_classes=num_class,
                          network=network,
                          W_bifpn=EFFICIENTDET[network]['W_bifpn'],
                          D_bifpn=EFFICIENTDET[network]['D_bifpn'],
                          D_class=EFFICIENTDET[network]['D_class'],
                          is_training=False)
# self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[args.gpu], find_unused_parameters=True)
# self.model = torch.nn.parallel.DistributedDataParallel(self.model)
if self.weights is not None:
    print('load state dict...', self.weights)
    checkpoint = torch.load(
        self.weights, map_location=lambda storage, loc: storage)
    state_dict = checkpoint['state_dict']
    self.model.load_state_dict(state_dict)
if torch.cuda.is_available():
    self.model = self.model.cuda()
self.model.eval()
This throws a missing-keys error from load_state_dict. How do I load a .pth file that was trained with DistributedDataParallel?
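For reference, a likely cause and a common workaround: DistributedDataParallel keeps the original network under a .module attribute, so the saved state_dict keys start with 'module.' while the bare EfficientDet expects keys without that prefix, which load_state_dict then reports as missing keys. Below is a minimal sketch of one usual fix, stripping the prefix before loading; it reuses self.weights and self.model from the evaluation code above, and everything else is standard PyTorch.

from collections import OrderedDict

import torch

# Drop the 'module.' prefix added by DistributedDataParallel,
# then load the cleaned dict into the unwrapped model.
checkpoint = torch.load(self.weights, map_location='cpu')
state_dict = checkpoint['state_dict']
cleaned = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in state_dict.items()
)
self.model.load_state_dict(cleaned)

Alternatives with the same effect are saving model.module.state_dict() at training time, or wrapping the evaluation model in DistributedDataParallel (or torch.nn.DataParallel) before calling load_state_dict, which is what the commented-out lines in the evaluation snippet would do.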