我尝试使用pyTorch读取检查点文件
checkpoint = torch.load(xxx.ckpt)
该文件是由使用python 2.7编写的程序生成的。我尝试使用python 3.6读取文件,但出现以下错误
UnicodeDecodeError: 'ascii' codec can't decode byte 0x8c in position 16: ordinal not in range(128)
是否可以在不降级python的情况下读取文件?
答案 0 :(得分:1)
在Python 2.x和Python 3.x之间,pickle
中存在一些兼容性问题,由于转向了unicode,您可能会将字符串保存为模型的一部分,因此这就是为什么错误。
您可以按照推荐的方法在Pytorch中保存模型并执行以下操作:
torch.save(filename, model.state_dict())
而不是保存model
。然后在Python3中:
model = Model() # construct a new model
model.load_state_dict(torch.load(filename))
另一种方法是在Python 2中解棘刺并将其保存为更易于在Python 2和Python 3之间传输的另一种格式。例如,您可以使用Pytorch-Numpy桥保存体系结构的张量并使用{{1 }}。
您也可以尝试使用np.savez
代替pickle
和tell it to decode ASCII strings to Python3 strings
答案 1 :(得分:0)
最终我通过以下方式解决了问题
1)使用anaconda创建python2环境
2)使用pytorch
读取检查点文件,然后使用pickle
checkpoint = torch.load("xxx.ckpt")
with open("xxx.pkl", "wb") as outfile:
pickle.dump(checkpointfile, outfile)
3)返回python3环境,使用pickle
读取文件,使用pytorch
保存文件
pkl_file = open("xxx.pkl", "rb")
data = pickle.load(pkl_file, encoding="latin1")
torch.save(data, "xxx.ckpt")