PyTorch LSTM does not overfit a single sample

Asked: 2019-03-20 22:33:55

Tags: python deep-learning lstm pytorch recurrent-neural-network

I am trying to overfit a single time series, that is, to train on one (X, Y) pair over and over again. I am doing this to get a feel for how the hyperparameters behave. But the training does not converge. Here is the loss plot, showing the MSE of each iteration over roughly 800 iterations:

[loss plot: per-iteration MSE over ~800 iterations, flattening into a plateau]

I expected the error to go all the way to zero, but as I write this it is stuck on a high plateau. The time series has a length of 29600, and the RNN maps a single value to another single value. It consists of one LSTM cell with 1 input and 200 hidden units, plus a fully connected layer that maps down to a single value.
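As an aside, the per-timestep `LSTMCell` loop this architecture implies can usually be replaced by a single `nn.LSTM` call over the whole sequence; it computes the same recurrence but lets the backend fuse the time steps, which is much faster on GPU. A minimal sketch of the described model (1 input, 200 hidden units, linear head mapping back to one value):

```python
import torch
import torch.nn as nn

class Sequence(nn.Module):
    """Same mapping as the LSTMCell-per-step version: (batch, seq) -> (batch, seq)."""
    def __init__(self, hidden_size=200):
        super().__init__()
        # nn.LSTM runs the whole sequence in one call instead of a Python loop
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):                    # x: (batch, seq)
        out, _ = self.lstm(x.unsqueeze(-1))  # out: (batch, seq, hidden)
        return self.linear(out).squeeze(-1)  # (batch, seq)

seq = Sequence()
y = seq(torch.randn(2, 50))
print(y.shape)  # torch.Size([2, 50])
```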

My gut feeling says the model may not be complex enough to fit the sample. But before increasing the complexity of the RNN, I want to make sure my training implementation is correct. Maybe I am not using autograd properly. Since this is the first time I have tried to train a neural network, I do not know how long it should take.
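One way to rule out a broken training loop is to first overfit a much shorter sequence: if the same setup cannot drive the MSE toward zero on, say, 50 time steps, the problem is in the implementation rather than in model capacity. A minimal, self-contained sketch of that sanity check (the sine data, hidden size, and learning rate are illustrative stand-ins for the real setup):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-in for the real data: map a short sine wave to a shifted copy.
T = 50
t = torch.linspace(0.0, 6.28, T)
x = torch.sin(t).reshape(1, T, 1)        # (batch, seq, features)
y = torch.sin(t + 0.5).reshape(1, T, 1)

lstm = nn.LSTM(input_size=1, hidden_size=32, batch_first=True)
head = nn.Linear(32, 1)
optimizer = optim.Adam(list(lstm.parameters()) + list(head.parameters()), lr=1e-2)
criterion = nn.MSELoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    out, _ = lstm(x)                     # run the whole 50-step sequence at once
    loss = criterion(head(out), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])             # the loss should fall sharply
```

If the final loss is near zero here but still plateaus on the full 29600-step series, the loop itself is fine and the issue lies elsewhere (sequence length, capacity, or learning rate).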

import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors suffice
import numpy as np

# the class containing the LSTM
class Sequence(nn.Module):
    def __init__(self):
        self.hidden_state_len = 200
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, self.hidden_state_len)        
        self.linear = nn.Linear(self.hidden_state_len, 1)


    def forward(self, input):
        outputs = []
        h_t = torch.zeros(input.size(0), self.hidden_state_len, dtype=torch.double).cuda()
        c_t = torch.zeros(input.size(0), self.hidden_state_len, dtype=torch.double).cuda()

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):            
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            output = self.linear(h_t)
            outputs += [output]

        return torch.cat(outputs, dim=1)


x1 = torch.load("/floyd/input/wav/x1.pt").double().cuda()[0][7400:37000].reshape(1,-1)
y1 = torch.load("/floyd/input/wav/y1.pt").double().cuda()[0][7400:37000].reshape(1,-1)

seq = Sequence()
seq.double()

criterion = nn.MSELoss()
seq = seq.cuda()
device = torch.device("cuda:0")
seq = seq.to(device)

optimizer = optim.Adam(seq.parameters())
starttime = datetime.datetime.now()
i = -1

# training for 4 hours on cloud GPU
while((datetime.datetime.now() - starttime).total_seconds() < 60*60*4) :
    i+=1
    optimizer.zero_grad()
    input = x1   # Variable wrappers are no longer needed; tensors track gradients directly
    target = y1
    out = seq(input)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(i, loss.item())  # log the MSE for the loss plot
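One more thing worth noting about the loop above: it backpropagates through all 29600 time steps in a single graph, which is slow and prone to vanishing or exploding gradients. A common alternative is truncated backpropagation through time: process the series in fixed-size windows and detach the hidden state between them. A hypothetical sketch (the window length, sizes, and the random stand-in data are placeholders, not the real setup):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

k = 100           # truncation window: gradients flow at most k steps back
hidden_size = 16
cell = nn.LSTMCell(1, hidden_size)
head = nn.Linear(hidden_size, 1)
optimizer = optim.Adam(list(cell.parameters()) + list(head.parameters()))
criterion = nn.MSELoss()

seq_len = 400
x = torch.randn(1, seq_len)   # stand-in for the real (1, 29600) series
y = torch.randn(1, seq_len)

h = torch.zeros(1, hidden_size)
c = torch.zeros(1, hidden_size)
for start in range(0, seq_len, k):
    optimizer.zero_grad()
    outputs = []
    for i in range(start, start + k):
        h, c = cell(x[:, i:i + 1], (h, c))
        outputs.append(head(h))
    loss = criterion(torch.cat(outputs, dim=1), y[:, start:start + k])
    loss.backward()
    optimizer.step()
    h, c = h.detach(), c.detach()   # cut the graph between windows
```

The state still carries information across the whole sequence; only the gradient path is truncated, so each `backward()` stays cheap regardless of total length.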

0 Answers:

There are no answers yet.