训练神经网络时如何节省中间权重

时间:2019-11-08 16:11:45

标签: pointers pytorch

我正在使用pytorch训练神经网络,我想在每次迭代时都保存权重。换句话说,我想创建一个列表,其中包含神经网络在训练过程中拥有的所有权重。

我做了以下事情:

for i, (images, labels) in enumerate(train_loader):

     (.....code that is used to train the model here.....)

     weight = model.fc2.weight.detach().numpy()
     weights_list.append(weight)

当我随后打印列表'weights_list'的条目时,我注意到它们都是相同的,这是不对的,因为我在训练期间已经打印了权重并且它们确实发生了变化(并且网络确实可以学习,所以他们必须)。 我的猜测是,列表中的每个条目实际上都是指向检查列表时网络权重的指针。因此:

1)我的猜测正确吗? 2)我该如何解决这个问题?

谢谢!

1 个答案:

答案 0 :(得分:0)

内置了保存和加载重量的功能。要保存到文件,可以使用

torch.save('checkpoint.pt', model.state_dict())

要加载,您可以使用

model.load_state_dict(torch.load('checkpoint.pt'))

也就是说,转换为numpy不一定会创建一个副本。例如,如果您有一个numpy数组y,并且想要创建一个副本,则可以使用

x = numpy.copy(y)