为什么在pytorch中简单的seq2seq总是返回NaN?

时间:2018-12-07 15:12:03

标签: pytorch seq2seq

我有以下精简型号:

import torch.nn as nn
import torch
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import json
import numpy as np
import datetime
import os


class EncoderRNN(nn.Module):
    def __init__(self, input_size=8, hidden_size=10, num_layers=2):
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #initialize weights
        nn.init.xavier_uniform(self.lstm.weight_ih_l0, gain=np.sqrt(2))
        nn.init.xavier_uniform(self.lstm.weight_hh_l0, gain=np.sqrt(2))

    def forward(self, input):
        tt = torch
        print input.shape
        h0 = Variable(tt.FloatTensor(self.num_layers, input.size(0), self.hidden_size))
        c0 = Variable(tt.FloatTensor(self.num_layers, input.size(0), self.hidden_size))
        encoded_input, hidden = self.lstm(input, (h0, c0))
        encoded_input = self.sigmoid(encoded_input)
        return encoded_input

train_x = torch.from_numpy(np.random.random((2000,19,8))).float()


train_loader = torch.utils.data.DataLoader(train_x,
    batch_size=64, shuffle=True)

model = EncoderRNN()

optimizer = optim.Adam(model.parameters(), lr=1e-6)

optimizer.zero_grad()


loss_function = torch.nn.BCELoss(reduce=True)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data_x) in enumerate(train_loader):
        x = model(Variable(data_x))
        print("output has nan: " + str(np.isnan(x.detach().numpy()).any()))


train(0)

总而言之,我认为我基本上只是将输入提供给具有随机初始化的隐藏值的LSTM,然后采用该LSTM输出的S形。然后,我将该输出馈送到解码器LSTM,并以解码器输出的输出的S形作为最终值。

不幸的是,即使在第一次迭代中,模型也经常输出正确形状的矢量(batch_size,seq_length,seq_dim),但至少包含一个NaN值,有时还包含所有NaN值。我在做什么错了?

谢谢!

到目前为止,我已经尝试过:

  • 将LSTM更改为GRU
  • 将sigmoid更改为relu
  • 更改隐藏表示的尺寸
  • 将失败的输入传递给编码器

编辑:对我在破坏代码后尝试提供帮助的所有人表示歉意-我真的很珍惜您的时间,并非常感谢您的帮助!

0 个答案:

没有答案