Pytorch:在训练时可视化模型

时间:2019-08-14 11:55:33

标签: python python-3.x pytorch torch

我正在通过回归训练神经网络,但它在测试过程中会预测一个恒定值。这就是为什么我想在训练过程中可视化神经网络变化的权重,并在jupyter notebook中看到权重的动态变化。
目前,我的模型如下:

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.inp = nn.Linear(2, 40)
        self.act1 = nn.Tanh()
        self.h1 = nn.Linear(40, 40)
        self.act2 = nn.Tanh()
        self.h2 = nn.Linear(40, 2)
        self.act3 = nn.Tanh()
        #self.h3 = nn.Linear(20, 20)
        #self.act4=nn.Tanh()
        self.h4 = nn.Linear(2, 1)

    def forward_one_pt(self, x):
        out = self.inp(x)
        out = self.act1(out)
        out = self.h1(out)
        out = self.act2(out)
        out = self.h2(out)
        out = self.act3(out)
        #out = self.h3(out)
        #out = self.act4(out)
        out = self.h4(out)
        return out

    def forward(self, config):
        E = torch.zeros([config.shape[0], 1])
        for i in range(config.shape[0]):
            E[i] = self.forward_one_pt(config[i])
            # print("config[",i,"] = ",config[i],"E[",i,"] = ",E[i])
        return torch.sum(E, 0)

我的主要功能如下:

def main()  :
    learning_rate = 0.5
    n_pts = 1000
    t_pts = 100
    epochs = 15

    coords,E = load_data(n_pts,t_pts)

    #generating my data to NN
    G = get_symm(coords,save,load_symmetry,symmtery_pickle_file,eeta1,eeta2,Rs,ex,lambdaa,zeta,boxl,Rc,pi,E,scale)
    net = Net()
    if(cuda_flag):
        net.cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    net_trained = train(save,text_output,epochs,n_pts,G,E,net,t_pts,optimizer,criterion,out,cuda_flag)
    test(save,n_pts,t_pts,G,E,net_trained,out,criterion,cuda_flag)

    torch.save(net,save_model)

任何教程或答案都会有所帮助

1 个答案:

答案 0 :(得分:0)

您可以使用model.state_dict()查看各个时期的权重是否正在更新:

old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

output = model(input)

new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))
    else:
        print('NO Diff in {}'.format(key))

另一方面,您可以向量化前进功能,而不是对其进行循环。 Follow可以完成与原始转发功能相同的工作,但速度更快:

def forward(self, config):
    out= self.forward_one_pt(config)
    return torch.sum(out, 0)