在pytorch中保存具有更新权重的模型

时间:2019-07-15 12:58:11

标签: python neural-network deep-learning save pytorch

我有一个要训练5个时代的模型。然后,我想看看模型哪里出了问题,并相应地增加了训练集。如何用学习的重量保存以下模型?

trainer_ = Trainer(network = network,
                       optimizer = optim.Adam(network.parameters(), lr=0.001),
                       loss_function = loss_function,
                       train_loader = train_loader,
                       valid_every = 100,
                       print_every = 50,
                       save_every = 15000,
                       save_path = ".",
                       cudaok = is_cuda_available)

trainer_.run(4,is_cuda_available)

我已经尝试过:

path = os.path.join(project_path, 'model.pth')
torch.save(network.cpu().state_dict(), path) # saving model

但是我真的不认为对象network包含权重。

我在这里很困惑。有人可以帮忙吗?谢谢!

2 个答案:

答案 0 :(得分:1)

<table width="100%" border="0" cellspacing="0" cellpadding="0"> <tr> <td style="text-align: center;"> <table width="500px" cellspacing="0" cellpadding="0" border="0" style="border-collapse: collapse; margin: 0 auto;"> <tbody> <tr> <td> <a href=""><img style="padding: 0;" src="https://ac-image.s3.amazonaws.com/6/7/7/4/3/7/home/admin/new_email_templates/group-1.jpg?r=715459984"> </a> </td> <td> <a href=""><img style="padding: 0;" src="https://ac-image.s3.amazonaws.com/6/7/7/4/3/7/home/admin/new_email_templates/group_2.jpg?r=1551841180"> </a> </td > <td> <a href=""><img style="padding: 0;" src="https://ac-image.s3.amazonaws.com/6/7/7/4/3/7/home/admin/new_email_templates/group.jpg?r=269367714"> </a> </td> </tr> </tbody> </table> </tr> </td> </table> network.state_dict();尝试一下以查看您的体重:

dictionary

答案 1 :(得分:0)

您正在正确保存模型。 现在要将权重加载到模型中,您可以使用以下参数创建一个新模型:

network = Network(*args, **kwargs)

,然后将保存的权重加载到其中:

network.load_state_dict(torch.load(path))