PyTorch网络测试代码无效

时间:2017-05-09 23:31:19

标签: pytorch

我正在尝试训练一个简单的MLP来近似y = f(a,b,c)。 我的代码如下。

import torch
import torch.nn as nn
from torch.autograd import Variable
# hyper parameters
input_size = 3
output_size = 1
num_epochs = 50
learning_rate = 0.001

# Netork definition
class FeedForwardNet(nn.Module):
    def __init__(self, l1_size, l2_size):
      super(FeedForwardNet, self).__init__()
      self.fc1 = nn.Linear(input_size, l1_size) 
      self.relu1 = nn.ReLU()
      self.fc2 = nn.Linear(l1_size, l2_size) 
      self.relu2 = nn.ReLU()
      self.fc3 = nn.Linear(l2_size, output_size)

    def forward(self, x):
      out = self.fc1(x)
      out = self.relu1(out)
      out = self.fc2(out)
      out = self.relu2(out)
      out = self.fc3(out)
      return out

model = FeedForwardNet(5 , 3)  

# sgd optimizer
optimizer = torch.optim.SGD(model.parameters(), learning_rate, momentum=0.9)

for epoch in range(11):
    print ('Epoch ', epoch)
    for i in range(trainX_light.shape[0]):
    X = Variable( torch.from_numpy(trainX_light[i]).view(-1, 3) )
    Y = Variable( torch.from_numpy(trainY_light[i]).view(-1, 1) )
    # forward
    optimizer.zero_grad()
    output = model(X)

    loss = (Y - output).pow(2).sum()
    print (output.data[0,0])
    loss.backward()
    optimizer.step()
    totalnorm = 0
    for p in model.parameters():
        modulenorm = p.grad.data.norm()
        totalnorm += modulenorm ** 2
        totalnorm = math.sqrt(totalnorm)

     print (totalnorm)

    # validation code
    if (epoch + 1) % 5 == 0:
    print (' test points',testX_light.shape[0])
    total_loss = 0
    for t in range(testX_light.shape[0]):
        X = Variable( torch.from_numpy(testX_light[t]).view(-1, 3) )
        Y = Variable( torch.from_numpy(testY_light[t]).view(-1, 1) )
        output = model(X)
        loss = (Y - output).pow(2).sum()
        print (output.data[0,0])
        total_loss += loss
    print ('epoch ', epoch, 'avg_loss ', total_loss.data[0] / testX_light.shape[0])

print ('Done')

我现在遇到的问题是验证码

  

输出=模型(X)

始终产生完全相同的输出值(我猜这个值是某种垃圾)。我不确定我在这部分做了什么错。有人可以帮我弄清楚代码中的错误吗?

1 个答案:

答案 0 :(得分:0)

回答我自己的问题。网络产生随机值(以及后来的inf)的原因是爆炸梯度问题。剪切渐变grep -nHIirE -- ([^,\n\s']),([^,\n\s\)\]']) 有帮助。