How do I load a pth file for a DistributedDataParallel model?

Time: 2020-11-01 03:13:14

Tags: deep-learning parallel-processing pytorch

I trained a model with DistributedDataParallel and saved it to a pth file:

    if args.gpu is not None:
        print('Gpu setting...', args.gpu)
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        # When using a single GPU per process and per
        # DistributedDataParallel, we need to divide the batch size
        # ourselves based on the total number of GPUs we have
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            # output_device=[args.gpu],
            find_unused_parameters=True)

Then I tried to evaluate the model:

    self.model = EfficientDet(num_classes=num_class,
                              network=network,
                              W_bifpn=EFFICIENTDET[network]['W_bifpn'],
                              D_bifpn=EFFICIENTDET[network]['D_bifpn'],
                              D_class=EFFICIENTDET[network]['D_class'],
                              is_training=False
                              )

    # self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[args.gpu], find_unused_parameters=True)
    # self.model = torch.nn.parallel.DistributedDataParallel(self.model)

    if self.weights is not None:
        print('load state dic...', self.weights)
        checkpoint = torch.load(
            self.weights, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        self.model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        self.model = self.model.cuda()
    self.model.eval()

This fails with a missing-keys error. How do I load a pth file from a model that was trained with DistributedDataParallel?
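
One likely cause, offered as an assumption because the question shows neither the saving code nor the full error message: DistributedDataParallel keeps the original network under its `module` attribute, so a state dict taken from the wrapped model has every key prefixed with `module.`. An unwrapped model then reports its own keys as missing. The sketch below strips that prefix before loading. The helper name `strip_module_prefix` is hypothetical, and a small `nn.Sequential` stands in for EfficientDet so the example runs on CPU without a process group.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper, not part of PyTorch. Keys without the prefix are kept as-is.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


def build_model():
    # Stand-in network (assumption); the question uses EfficientDet here.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


trained = build_model()

# Simulate a checkpoint saved from a DDP-wrapped model:
# every parameter key carries the "module." prefix.
ddp_style_state = OrderedDict(
    ("module." + key, value.clone()) for key, value in trained.state_dict().items()
)
checkpoint = {"state_dict": ddp_style_state}

# Load into a plain, unwrapped model for evaluation.
eval_model = build_model()
eval_model.load_state_dict(strip_module_prefix(checkpoint["state_dict"]))
eval_model.eval()
```

In the evaluation code from the question, the same helper would be applied to `checkpoint['state_dict']` just before `self.model.load_state_dict(...)`. If the training script can be changed, saving `model.module.state_dict()` instead of `model.state_dict()` should avoid the prefix in the first place.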

0 answers:

No answers yet