the_model = TheModelClass(* args,** kwargs)是什么意思?

时间:2018-09-22 09:23:14

标签: python pytorch

我正在使用PyTorch进行图像分类。经过最多的训练后,我想保存训练后的模型。

我不明白这是什么意思

the_model = TheModelClass(*args, **kwargs)

这行代码是由PyTorch网站(https://pytorch.org/docs/master/notes/serialization.html)提供的。

1 个答案:

答案 0 :(得分:0)

这个问题是the_model = TheModelClass(*args, **kwargs) 意味着您必须首先定义一个 ModelClass 对象。然后你可以使用模型对象来加载磁盘顺序对象。例如:

in_feats = data.x.shape[1]
n_hidden = params["n_hidden"]
n_classes = 2
best_model = OwnGCN(in_c=in_feats, hid_c=n_hidden, out_c=n_classes)
best_model.load_state_dict(torch.load(PATH))