我有一个PyTorch模型,其测试准确度约为95%-97%。我使用torch.save(my_model.state_dict(),PATH)保存了该文件,但是每当我尝试使用my_model.load_state_dict(torch.load(PATH))
重新加载它并使用test_fn(my_model)
在相同的数据上对其进行测试时,我的测试精度都会下降到大约0.06%。我正在尝试遵循建议的序列化语义(https://pytorch.org/docs/stable/notes/serialization.html)
无论我是否使用my_model.eval(),都会发生这种情况(尽管通过默认,我并未将其用于培训或测试)。我还需要采取其他步骤吗?
在代码中,它看起来像:
my_model = GraphConv(w2i, p2i, l2i, r2i, s2i, words, pos, lems, 512, 512, 3) ## Initialise model & params
my_model.cuda()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(my_model.parameters(), lr=0.001)
for epoch in range(15):
... ### Apply training steps
print(test_fn(my_model)) ### Will be over 95%
torch.save(my_model.state_dict(), PATH)
...
my_model2 = GraphConv(w2i, p2i, l2i, r2i, s2i, words, pos, lems, 512, 512, 3) ## Initialise new model
my_model2.load_state_dict(torch.load('PATH'))
print(test_fn(my_model2)) ### Is about 0.06%