pytorch LSTM回归;损失函数在每个时期都达到峰值

时间:2019-11-04 10:55:03

标签: regression pytorch lstm

我正在自学pytorch,并尝试 LSTM 来解决剩余生命周期回归(RUL)问题。

我的输入矩阵由矢量中的特征变量组成,这些矢量沿行向下堆积[ 2D时间矩阵(按特征矢量)],我在输入中输入了经过预处理的DataLoader每个enumerate(train_loader)

的[batchsize,序列长度,特征数]的大小

The training output of RUL for 50th epoch

RUL of the whole run

我正在尝试使用MSELoss,带有RMSprop的L1Loss,SGD优化器之类的损失函数,目标是沿着时间步长简单地降低RUL值。

因此问题是,在每个时期,随着新时期的进行,损耗计算从头开始。这可能是由于目标的高整数值引起的高损耗值。但是,随着时代的重复,由于学习已经完成,因此损失的起始值预计会下降。但是结果却不是我所期望的。

希望有人可以帮助我解决这个问题。 预先谢谢你.. !!!

遵循模型源:

class GRU(nn.Module):
    def __init__(self, input_dim, hidden_dim, batch_size, window_size, output_dim=1, num_layers=2):
        super(GRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.window_size = window_size

        # Define the GRU layer
        self.gru1 = nn.GRU(
                input_size=self.input_dim, 
                hidden_size=hidden_dim,
                bias=True)

        # Define the output layer
        self.linear1 = nn.Linear(hidden_dim, 500)
        self.linear2 = nn.Linear(500, 1)

    def forward(self, input, hidden):
        input=input.transpose(1,0)
        gru_out, self.hidden1 = self.gru1(input.view(self.window_size, self.batch_size, -1), hidden)

        y_pred = self.linear1(gru_out[-1].view(self.batch_size, -1))
        y_pred = self.linear2(y_pred)

        return y_pred.view(-1), self.hidden1.detach()

以下是我的火车来源:

#input shape: [sequence_length, batch_size, input_size]
#target shape: [batch_size, 1]

for j in range(epoch):
    for i, (dat, target) in enumerate(train_loader):

        out, hidden = model(dat, hidden)
        loss = criterion(out, target.view(-1))
        loss.backward(retain_graph=True)

        optimizer.step()
        optimizer.zero_grad()

        loss_history.append(loss.item())

        if i % 100 == 0:    
            print("{} epoch: {}/{}timestep, time: {}, loss {}:".format(j, i,len(train_loader), round(t2-t1,4), loss))
        RUL=torch.cat((RUL,out.data.view(-1)))

    plt.plot(RUL[-28000:].detach().numpy(), linewidth=0.5)
    plt.show()

0 个答案:

没有答案