我正在尝试加载一个预先训练过的模型
model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'}
当我使用以下代码时,它总是将模型加载到cuda:0。如果我想将它加载到cuda:3?
,该怎么办?model = ResNet(BasicBlock, [3, 4, 6, 3])
device = 3
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'],
map_location=lambda storage, loc: storage.cuda(device)))
答案 0 :(得分:0)
这应该适合你:
device = torch.device('cuda')
model = ResNet(BasicBlock, [3, 4, 6, 3])
with torch.cuda.device(3):
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'],
map_location=lambda storage, loc: storage.cuda(device)))
我认为这适用于版本0.4.0及更高版本,您可以查看0.4.0中的其他一些示例。迁移指南: https://pytorch.org/2018/04/22/0_4_0-migration-guide.html