加载pytorch模型,测试准确性下降

时间:2018-08-10 08:02:55

标签: machine-learning pytorch

我有一个PyTorch模型,其测试准确度约为95%-97%。我使用torch.save(my_model.state_dict(),PATH)保存了该文件,但是每当我尝试使用my_model.load_state_dict(torch.load(PATH))重新加载它并使用test_fn(my_model)在相同的数据上对其进行测试时,我的测试精度都会下降到大约0.06%。我正在尝试遵循建议的序列化语义(https://pytorch.org/docs/stable/notes/serialization.html

无论我是否使用my_model.eval(),都会发生这种情况(尽管通过默认,我并未将其用于培训或测试)。我还需要采取其他步骤吗?

在代码中,它看起来像:

my_model = GraphConv(w2i, p2i, l2i, r2i, s2i, words, pos, lems, 512, 512, 3) ## Initialise model & params
my_model.cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(my_model.parameters(), lr=0.001)

for epoch in range(15):
    ...   ### Apply training steps
    print(test_fn(my_model))  ### Will be over 95%
    torch.save(my_model.state_dict(), PATH)

...
my_model2 = GraphConv(w2i, p2i, l2i, r2i, s2i, words, pos, lems, 512, 512, 3) ## Initialise new model
my_model2.load_state_dict(torch.load('PATH'))
print(test_fn(my_model2))  ### Is about 0.06%

0 个答案:

没有答案