Running out of RAM when training an LSTM

Asked: 2019-11-12 20:21:38

Tags: python memory pytorch lstm torch

I am a beginner with RNNs, so I implemented an LSTM architecture in PyTorch, but I always run out of RAM by the time I reach the third epoch. I am already using a DataLoader, and I tried detaching the gradients from the input tensors, but that did not solve the problem.

Here is my training loop:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from datetime import datetime

writer = SummaryWriter()
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index = 0)
optimizer = optim.Adam(lstm.parameters(), lr = 1e-5)
gradient_clip = clip_grad_norm_(lstm.parameters(), max_norm = 5)

num_epochs = 20
epoch_loss = -1.0
loss = - 1

t = trange(num_epochs, desc= "Epoch loss",  leave=True)

for epoch in t:
    trainLoader = iter(DataLoader(dataset, batch_size = batch_size))

    tt = trange(len(trainLoader)-1,  desc= "Batch loss",  leave=True)

    for i in tt:

        text, embedding = next(trainLoader)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize

        y = lstm.forward(embedding.transpose(1,0))

        labels = text.transpose(0,1)[1:].transpose(0,1).flatten()

        loss = criterion(y.reshape(-1, y.shape[-1]), labels)


        tt.set_description("Batch loss : %.4f" % loss)

        tt.refresh()

        loss.backward(retain_graph=True)

        optimizer.step()

        epoch_loss += loss


    epoch_loss = epoch_loss / (len(trainLoader) - 1)

    # Saving model
    save_date = datetime.now().strftime("%d%m%Y-%H:%M:%S")
    PATH = './save/lstm_model_'+save_date
    torch.save(lstm, PATH)

    # Updating progression bar

    t.set_description("Epoch loss : %.4f" % epoch_loss)

    t.refresh()

    # Plotting gradients histograms in Tensorboard

    writer.add_scalar('Text_generation_Loss/train', epoch_loss, epoch)

    for tag, parm in lstm.named_parameters():

        with torch.no_grad():

            writer.add_histogram(tag, parm.grad.data.cpu().numpy(), epoch)


    writer.flush()

print('Finished Training')

writer.close()
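Two details of this loop are worth noting with respect to memory: epoch_loss += loss accumulates the loss tensor itself, which keeps each batch's autograd graph alive, and loss.backward(retain_graph=True) prevents the graph from being freed after the backward pass. A minimal sketch of the usual pattern (illustrative only, not the fix described in the answer below):

loss.backward()              # no retain_graph: the graph is freed after the backward pass
optimizer.step()
epoch_loss += loss.item()    # .item() extracts a plain float, so the graph can be released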

Here is the LSTM class I built:

class LSTM(nn.Module):

    def __init__(self, in_size : int, hidden_size : int):
        super().__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.W_fi = nn.Linear(in_size,hidden_size)
        self.W_fh = nn.Linear(hidden_size,hidden_size, bias=False)
        self.W_ii = nn.Linear(in_size,hidden_size)
        self.W_ih = nn.Linear(hidden_size,hidden_size, bias=False)
        self.W_Ci = nn.Linear(in_size,hidden_size)
        self.W_Ch = nn.Linear(hidden_size,hidden_size, bias=False)
        self.W_oi = nn.Linear(in_size,hidden_size)
        self.W_oh = nn.Linear(hidden_size,hidden_size, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def one_step(self, x, h, C):
        f_t = self.sigmoid(self.W_fi(x) + self.W_fh(h))
        i_t = self.sigmoid(self.W_ii(x) + self.W_ih(h))
        g_t = self.tanh(self.W_Ci(x) + self.W_Ch(h))
        C_t = torch.mul(f_t, C) + torch.mul(i_t, g_t)
        o_t = self.sigmoid(self.W_oi(x) + self.W_oh(h))
        h_t = torch.mul(o_t, self.tanh(C_t))
        return h_t, C_t

    def forward(self, X):
        h_out = []
        h = - torch.ones(X.shape[1], self.hidden_size)
        C = - torch.ones(X.shape[1], self.hidden_size)
        h_t, C_t = self.one_step(X[0], h, C)
        h_out.append(h_t)

        for i in range(1, X.shape[0]  - 1):
            h_t, C_t = self.one_step(X[i], h_t, C_t)
            h_out.append(h_t)
        h_out = torch.cat(h_out)

        return h_out #h_out.reshape(-1,batch_size,num_embeddings)
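For reference, the forward pass above expects X of shape (seq_len, batch, in_size) and returns the hidden states of the processed timesteps concatenated along the first dimension. A minimal call sketch (the sizes here are made up for illustration):

lstm = LSTM(in_size=128, hidden_size=256)
X = torch.randn(10, 4, 128)   # (seq_len, batch, in_size)
h_out = lstm(X)               # shape ((seq_len - 1) * batch, hidden_size) after torch.cat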

I have searched for similar cases but could not find a solution.

1 Answer:

Answer 0 (score: 0):

I don't know whether this will help anyone, but I solved the problem. I may not have been clear about the task, but the goal is text generation. The first thing I did was embed the sentences with a torch.nn.Embedding defined outside the LSTM. The solution was to include it inside my network, because the embedding is not pretrained and should therefore be learned as well.
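A minimal sketch of what that change looks like (the names and sizes below are illustrative, not the poster's actual code): the embedding layer becomes a sub-module of the model, so its weights receive gradients and are updated by the same optimizer as the LSTM.

class LSTMWithEmbedding(nn.Module):

    def __init__(self, num_embeddings: int, embedding_dim: int, hidden_size: int):
        super().__init__()
        # The embedding is part of the module, so it is trained together with the LSTM.
        # padding_idx=0 mirrors ignore_index=0 in the loss; this is an assumption.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm = LSTM(embedding_dim, hidden_size)  # the LSTM class defined above

    def forward(self, token_ids):
        # token_ids: (seq_len, batch) tensor of integer token indices
        embedded = self.embedding(token_ids)          # (seq_len, batch, embedding_dim)
        return self.lstm(embedded)

The optimizer then has to be created from this wrapper's parameters, e.g. optim.Adam(model.parameters(), lr=1e-5), so that the embedding weights are updated as well.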