我正在使用pytorch训练神经网络,我想在每次迭代时都保存权重。换句话说,我想创建一个列表,其中包含神经网络在训练过程中拥有的所有权重。
我做了以下事情:
for i, (images, labels) in enumerate(train_loader):
(.....code that is used to train the model here.....)
weight = model.fc2.weight.detach().numpy()
weights_list.append(weight)
当我随后打印列表'weights_list'的条目时,我注意到它们都是相同的,这是不对的,因为我在训练期间已经打印了权重并且它们确实发生了变化(并且网络确实可以学习,所以他们必须)。 我的猜测是,列表中的每个条目实际上都是指向检查列表时网络权重的指针。因此:
1)我的猜测正确吗? 2)我该如何解决这个问题?
谢谢!
答案 0 :(得分:0)
内置了保存和加载重量的功能。要保存到文件,可以使用
torch.save('checkpoint.pt', model.state_dict())
要加载,您可以使用
model.load_state_dict(torch.load('checkpoint.pt'))
也就是说,转换为numpy不一定会创建一个副本。例如,如果您有一个numpy数组y
,并且想要创建一个副本,则可以使用
x = numpy.copy(y)