我正在使用一个github存储库,其中包含训练有素的CNN,其权重参数在.npy
文件中给出。模型正在加载权重,并使用像这样的模型参数:-
model = CNN_Model(batch_size)
filename = "weight_file.npy"
dtype = torch.FloatTensor
model.load_state_dict(load_weights(model, weight_file, dtype))
load_weights
被定义为:-
def load_weights(model, filename, dtype):
model_params = model.state_dict()
data_dict = np.load(filename, encoding='latin1').item()
model_params["conv1.weight"] = torch.from_numpy(data_dict["conv1"] ["weights"]).type(dtype).permute(3,2,0,1)
model_params["conv1.bias"] = torch.from_numpy(data_dict["conv1"]["biases"]).type(dtype)
model_params["bn1.weight"] = torch.from_numpy(data_dict["bn_conv1"]["scale"]).type(dtype)
model_params["bn1.bias"] = torch.from_numpy(data_dict["bn_conv1"]["offset"]).type(dtype)
return model_params
我已经在其中添加了一个训练模块,并试图微调我自己的数据集上的权重。训练后,我想将新的权重保存在.npy
的文件中,索引的索引为data_dict
,与先前加载的权重文件中的索引相同,因此我可以将其再次用于CNN模型。
在使用以下方法保存data_dict数组之前,应如何使用相似的名称进行索引编制:
np.save("trained_weight_file.npy", data_dict)
编辑1:- 因此,根据我推荐的{@ a-d
data_dict = model.state_dict()
它所做的是它保存了索引为model_params
的所有权重。 print data_dict
的输出为:-
OrderedDict([('conv1.weight', tensor([[[[....]]]])), ('conv1.bias', tensor([....])), , ('bn1.weight', tensor([....])), ('bn1.bias', tensor([....]))])
但是我需要存储在data_dict
索引中,以便我可以使用相同的算法从.npy
文件中读取它。我也尝试从data_dict
定义中返回model_params
和load_weights
,然后尝试使用data_dict = model.state_dict()
,但是它给了我`model.load_state_dict(load_weights(model,weight_file, dtype))'行,即:-
回溯(最近通话最近): model.load_state_dict(load_weights(model,weight_file,dtype)) state_dict = state_dict.copy() AttributeError:“ tuple”对象没有属性“ copy”
答案 0 :(得分:0)
我会做类似data_dict = model.state_dict()
的事情。
您可以阅读带有state_dict()
here输出示例的官方文档。
有一个github repository是github存储库的基础,您可以从中获取代码。该存储库也使用model.state_dict()
来存储值。