我搜索了很多资源来解决此问题,但仍然停留在这里。
我遵循了pytorch教程,并使用
保存了参数torch.save(the_model.state_dict(), PATH)
然后使用
加载参数the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
加载参数后,我打印了模型。它出来了一个错误。
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
我发现有些人面临同样的问题,但似乎可以忽略不计?!然后我尝试使用``model.forward()'''将数据输入到模型中,这又出现了另一个错误。
AttributeError: 'IncompatibleKeys' object has no attribute 'forward'
我知道这种保存方法(the_model.state_dict()
)只是保存“权重”。由于某些重要信息(例如,辍学,batchnorm等)未保存,因此只能使用.eval()
。所以我尝试model.eval()
,它仍然有相同的错误。
AttributeError: 'IncompatibleKeys' object has no attribute 'eval'
以下是一些相关代码:
初始化模型:
model = VAE(some constructor parameters)
训练后:
checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
torch.save(model.state_dict(), checkpoint_path)
初始化相同的模型并将参数加载到模型中:
model = VAE(some constructor parameters)
checkpoint = torch.load("E24.pkl", map_location='cuda:0')
model = model.load_state_dict(checkpoint)
我将不再训练该模型。我只想加载参数,然后检查性能。谢谢阅读。 :)