When I try to load a model trained on the GPU onto the CPU with the following code:
model_conv.load_state_dict(torch.load(model_file, map_location='cpu'))
model_conv = model_conv.cpu()
I get this error message:
Traceback (most recent call last):
  File "prediction.py", line 269, in <module>
    model_conv.load_state_dict(torch.load(resume_file, map_location='cpu'))
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 229, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 377, in _load
    result = unpickler.load()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 348, in persistent_load
    data_type(size), location)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 246, in restore_location
    result = map_location(storage, location)
TypeError: 'str' object is not callable
My PyTorch version is 0.1.12_1. Any ideas how to fix this? I have already checked "how to load the gpu trained model into the cpu?", but that solution doesn't seem to work in my case.
Any suggestions are appreciated!
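Reading the first traceback, torch.load in this version appears to call map_location(storage, location) directly, so it expects a function rather than a string such as 'cpu' (later PyTorch releases do accept strings). The sketch below is a simplified pure-Python model of that dispatch, based only on the traceback; restore_location and default_restore are hypothetical stand-ins, not the real torch internals. It illustrates why the usual workaround for old versions is to pass a callable that returns the storage unchanged, e.g. torch.load(model_file, map_location=lambda storage, loc: storage).

```python
# Simplified, hypothetical model of how this old torch.load dispatches
# map_location, inferred from the traceback (not the real torch source).


def default_restore(storage, location):
    # Stand-in for the default path: CUDA-tagged storages would be moved
    # to a GPU, which fails on a machine without an NVIDIA driver.
    if location.startswith("cuda"):
        raise AssertionError("Found no NVIDIA driver on your system.")
    return storage


def restore_location(storage, location, map_location=None):
    if map_location is None:
        return default_restore(storage, location)
    # Any other value is called as a function, so passing a str raises
    # TypeError: 'str' object is not callable.
    return map_location(storage, location)


if __name__ == "__main__":
    storage = [0.1, 0.2]  # placeholder for a storage saved on "cuda:0"
    try:
        restore_location(storage, "cuda:0", map_location="cpu")
    except TypeError as exc:
        print("string map_location ->", exc)
    # A callable that returns the storage as-is never touches CUDA;
    # this mirrors map_location=lambda storage, loc: storage.
    print(restore_location(storage, "cuda:0", map_location=lambda s, loc: s))
```

In this model the lambda simply skips the CUDA branch, which is the intent of the callable form on a CPU-only machine.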
--Update--
If I don't pass the map_location argument, the error message is:
Traceback (most recent call last):
  File "prediction.py", line 268, in <module>
    model_conv.load_state_dict(torch.load(resume_file))
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 229, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 377, in _load
    result = unpickler.load()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 348, in persistent_load
    data_type(size), location)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 85, in default_restore_location
    result = fn(storage, location)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 67, in _cuda_deserialize
    return obj.cuda(device_id)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/_utils.py", line 57, in _cuda
    with torch.cuda.device(device):
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/cuda/__init__.py", line 124, in __enter__
    _lazy_init()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/cuda/__init__.py", line 84, in _lazy_init
    _check_driver()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/cuda/__init__.py", line 58, in _check_driver
    http://www.nvidia.com/Download/index.aspx""")
AssertionError:
Found no NVIDIA driver on your system. Please check that you
have an NVIDIA GPU and installed a driver from
http://www.nvidia.com/Download/index.aspx
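The second traceback shows the default path: each saved storage carries a location tag (assumed here to look like 'cuda:0'), and _cuda_deserialize calls obj.cuda(device_id), which initializes CUDA and trips the driver check. Below is a hypothetical pure-Python model of that tag handling; FakeStorage, NoDriverError and parse_cuda_location are illustrative names, not torch API.

```python
class NoDriverError(AssertionError):
    """Stands in for the 'Found no NVIDIA driver' assertion."""


class FakeStorage:
    # Hypothetical placeholder for a deserialized tensor storage.
    def cuda(self, device_id):
        # On a CPU-only machine, CUDA initialization fails before any copy.
        raise NoDriverError("Found no NVIDIA driver on your system.")


def parse_cuda_location(location):
    """Return the device index in a tag such as 'cuda:0', or None if not CUDA."""
    if not location.startswith("cuda"):
        return None
    suffix = location[len("cuda:"):]
    return int(suffix) if suffix else 0


def default_restore(storage, location):
    device_id = parse_cuda_location(location)
    if device_id is None:
        return storage  # CPU-tagged storages are returned unchanged
    return storage.cuda(device_id)  # CUDA-tagged storages go to the GPU


if __name__ == "__main__":
    try:
        default_restore(FakeStorage(), "cuda:0")
    except NoDriverError as exc:
        print("default path ->", exc)
```

Under this model, any checkpoint saved from GPU tensors hits storage.cuda(...) unless map_location intercepts the storage first.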