在Pytorch中加载权重时会遇到一些问题

时间:2019-10-25 08:25:53

标签: python pytorch

我搜索了很多资源来解决此问题,但仍然停留在这里。

我遵循了pytorch教程,并使用

保存了参数
torch.save(the_model.state_dict(), PATH)

然后使用

加载参数
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

加载参数后,我打印了模型。它出来了一个错误。

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

我发现有些人面临同样的问题,但似乎可以忽略不计?!然后我尝试使用``model.forward()'''将数据输入到模型中,这又出现了另一个错误。

AttributeError: 'IncompatibleKeys' object has no attribute 'forward'

我知道这种保存方法(the_model.state_dict())只是保存“权重”。由于某些重要信息(例如,辍学,batchnorm等)未保存,因此只能使用.eval()。所以我尝试model.eval(),它仍然有相同的错误。

AttributeError: 'IncompatibleKeys' object has no attribute 'eval'

以下是一些相关代码:

  • 初始化模型:

    model = VAE(some constructor parameters)

  • 训练后:

    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    
  • 初始化相同的模型并将参数加载到模型中:

    model = VAE(some constructor parameters)
    checkpoint = torch.load("E24.pkl", map_location='cuda:0')
    model = model.load_state_dict(checkpoint)
    

我将不再训练该模型。我只想加载参数,然后检查性能。谢谢阅读。 :)

0 个答案:

没有答案