在训练pytorch RNN时,损失并没有减少

时间:2018-04-01 18:25:45

标签: python machine-learning pytorch rnn

这是我为情绪设计的RNN网络。

class rnn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.h2h = nn.Linear(hidden_size , hidden_size)
        self.relu = nn.Tanh()
        self.sigmoid = nn.LogSigmoid()

    def forward(self, input, hidden):
        hidden_new = self.relu(self.i2h(input)+self.h2h(hidden))
        output = self.h2o(hidden)
        output = self.sigmoid(output)
        return output, hidden_new

    def init_hidden(self):
        return Variable(torch.zeros(1, self.hidden_size))

然后,我创建并训练网络:

RNN = rnn(50, 50, 1)
learning_rate = 0.0005
criteria = nn.MSELoss()
optimizer = optim.Adam(RNN.parameters(), lr=learning_rate)
hidden = RNN.init_hidden()
epochs = 2
for epoch in range(epochs):
    for i in range(len(train['Phrase'])):
        input = convert_to_vectors(train['Phrase'][i])
        for j in range(len(input)):
            temp_input = Variable(torch.FloatTensor(input[j]))
            output, hidden = RNN(temp_input, hidden)
        temp_output = torch.FloatTensor([np.float64(train['Sentiment'][i])/4])
        loss = criteria( output, Variable(temp_output))
        loss.backward(retain_graph = True)
        if (i%20 == 0):
            print('Current loss is ', loss)

问题在于网络的丢失并没有减少。它会增加,然后减少,等等。它根本不稳定。我尝试使用较小的学习率,但似乎没有帮助。

为什么会发生这种情况,我该如何纠正?

2 个答案:

答案 0 :(得分:1)

您只需在执行optimizer.step()之后致电loss.backward()

顺便说一下,这说明了一个常见的误解:反向传播不是学习算法,它只是计算损耗w.r.t的梯度的一种很酷的方法。您的参数。然后,您可以使用某些梯度下降方式(例如,普通的SGD,AdaGrad等,在您的情况下为Adam)来更新给定梯度的权重。

答案 1 :(得分:0)

我认为有些事情可能会给你一些帮助。 首先,在rnn类模块中,您最好使用"super(rnn,self).__init__()"来替换"super().__init__()"

其次,变量名称应与函数一致,您最好使用"self.tanh = nn.Tanh()"替换"self.relu = nn.Tanh()"。在rnn中,sigmoid函数应该是1/(1+exp(-x)),而不是logsigmoid函数。您应该使用"self.sigmoid = nn.Sigmoid()"替换"self.sigmoid = nn.LogSigmoid()"。 第三,如果使用rnn进行分类,则应通过softmax函数激活输出。因此,您应添加两个语句"self.softmax = nn.Softmax()""output = self.softmax(output)"