我正在使用PyTorch进行图像分类。经过最多的训练后,我想保存训练后的模型。
我不明白这是什么意思
the_model = TheModelClass(*args, **kwargs)
这行代码是由PyTorch网站(https://pytorch.org/docs/master/notes/serialization.html)提供的。
答案 0 :(得分:0)
这个问题是the_model = TheModelClass(*args, **kwargs)
意味着您必须首先定义一个 ModelClass 对象。然后你可以使用模型对象来加载磁盘顺序对象。例如:
in_feats = data.x.shape[1]
n_hidden = params["n_hidden"]
n_classes = 2
best_model = OwnGCN(in_c=in_feats, hid_c=n_hidden, out_c=n_classes)
best_model.load_state_dict(torch.load(PATH))