Is this a correct re-implementation of a PyTorch Seq2Seq model?

Date: 2018-02-05 01:16:06

Tags: python nlp deep-learning artificial-intelligence pytorch

I wrote some code that modifies the seq2seq tutorial script provided by PyTorch. Here is the model:

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

# use_cuda, gpu and sample_gumbel are defined elsewhere in my script.

class Seq2Seq(nn.Module):
    def __init__(self, encoder, batch_size, vocab_size, input_size, output_size,
                 hidden_dim, embedding_dim, n_layers=2, dropout_p=0.5):
        super(Seq2Seq, self).__init__()

        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.input_length = input_size
        self.output_length = output_size
        self.vocab_size = vocab_size

        self.encoder = encoder
        self.dropout = nn.Dropout(dropout_p)
        self.selu = nn.SELU()
        self.decoder_embeddings = nn.Embedding(vocab_size, hidden_dim)
        self.decoder_gru = nn.GRU(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.LogSoftmax()

    def decode(self, SOS_token, encoder_hidden, target_output, teacher_forcing_ratio=0.8):
        # Collects the decoder distribution for every time step: (output_length, batch, vocab).
        decoder_output_full = autograd.Variable(torch.zeros(self.output_length, self.batch_size, self.vocab_size))
        decoder_output_full = decoder_output_full.cuda() if use_cuda else decoder_output_full
        target = target_output.permute(1, 0)

        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        for idx in range(self.output_length):
            if idx == 0:
                decoder_input = SOS_token
                decoder_hidden = encoder_hidden.unsqueeze(0)
            output = self.decoder_embeddings(decoder_input).view(1, self.batch_size, -1)
            output = self.dropout(output)

            output = self.selu(output)

            if use_teacher_forcing:
                decoder_output, decoder_hidden = self.decoder_gru(output, decoder_hidden)
                temp = 1
                out = self.out(decoder_output[0])
                out = out + sample_gumbel(out.shape)
                decoder_output = F.softmax(out / temp, dim=1)
                # decoder_output = (self.decoder_embeddings.weight * decoder_output.unsqueeze(1)).sum(0).view(1, 1, -1)
                decoder_output_full[idx, :, :] = decoder_output
                decoder_input = target[idx - 1]  # Teacher forcing

            else:
                decoder_output, decoder_hidden = self.decoder_gru(output, decoder_hidden)
                temp = 1
                out = self.out(decoder_output[0])
                out = out + sample_gumbel(out.shape)
                decoder_output = F.softmax(out / temp, dim=1)
                # decoder_output = (self.decoder_embeddings.weight * decoder_output.unsqueeze(1)).sum(0).view(1, 1, -1)
                topv, topi = decoder_output.data.topk(1)
                ni = topi
                # decoder_input_v = autograd.Variable(torch.LongTensor([[ni]]))
                decoder_input = autograd.Variable(ni)
                # decoder_input = decoder_input.cuda() if use_cuda else decoder_input
                decoder_output_full[idx, :, :] = decoder_output

        # Reorder to (batch, output_length, vocab) before returning.
        decoder_output_full = decoder_output_full.permute(1, 0, 2)

        # gen_output = self.softmax(self.out(decoder_output_full))

        return decoder_output_full

    def forward(self, input, target_output, teacher_forcing_ratio=0.8):
        encoder_feat, _ = self.encoder(input)

        # A batch of SOS tokens (index 0) to start decoding.
        SOS_token = np.zeros((self.batch_size, 1), dtype=np.int32)
        SOS_token = torch.LongTensor(SOS_token.tolist())
        SOS_token = autograd.Variable(SOS_token)
        if use_cuda:
            SOS_token = SOS_token.cuda(gpu)

        gen_output = self.decode(SOS_token, encoder_feat, target_output, teacher_forcing_ratio)

        return gen_output

    def initHidden(self):
        result = autograd.Variable(torch.zeros(1, self.batch_size, self.hidden_dim))
        if use_cuda:
            return result.cuda()
        else:
            return result
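
For context, here is a minimal sketch of how a model like this can be instantiated and run. The Encoder, the sample_gumbel stub, and all of the sizes below are placeholders chosen to show the expected shapes, not my actual training setup:

import numpy as np
import torch
import torch.nn as nn
from torch import autograd

use_cuda = torch.cuda.is_available()
gpu = 0

def sample_gumbel(shape, eps=1e-10):
    # Stand-in for my real sample_gumbel: Gumbel(0, 1) noise of the given shape.
    u = torch.rand(*shape)
    u = u.cuda() if use_cuda else u
    return autograd.Variable(-torch.log(-torch.log(u + eps) + eps))

class Encoder(nn.Module):
    # Placeholder encoder: returns a (batch, hidden_dim) feature that
    # Seq2Seq.decode unsqueezes into the initial decoder hidden state.
    def __init__(self, vocab_size, hidden_dim):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        out, hidden = self.gru(self.embed(src))
        return hidden.squeeze(0), out

batch_size, vocab_size, in_len, out_len, hidden_dim = 4, 1000, 15, 20, 256

encoder = Encoder(vocab_size, hidden_dim)
model = Seq2Seq(encoder, batch_size, vocab_size, in_len, out_len,
                hidden_dim, embedding_dim=hidden_dim)

src = autograd.Variable(torch.from_numpy(np.random.randint(0, vocab_size, (batch_size, in_len))))
tgt = autograd.Variable(torch.from_numpy(np.random.randint(0, vocab_size, (batch_size, out_len))))

if use_cuda:
    model = model.cuda()
    src, tgt = src.cuda(gpu), tgt.cuda(gpu)

gen = model(src, tgt)   # (batch_size, out_len, vocab_size)
print(gen.size())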

The way I compute the NLL loss is to first build the full output sequence and then compare it with the target output. Here is the loss function:

class batchNLLLoss(nn.Module):
    def __init__(self):
        super(batchNLLLoss, self).__init__()

    def forward(self, synt, target, claim_length=20):
        loss_fn = nn.NLLLoss()

        loss = 0

        # Sum the per-token NLL over every position of every sequence in the batch.
        for i in range(synt.shape[0]):
            for j in range(claim_length):
                loss += loss_fn(synt[i][j].unsqueeze(0), target[i][j])

        return loss
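
For reference, here is a vectorized sketch of the same computation. The batch_nll helper below is hypothetical and assumes synt holds per-token log-probabilities of shape (batch, claim_length, vocab_size), which is what nn.NLLLoss expects, and target is the matching (batch, claim_length) LongTensor of indices:

import torch
import torch.nn as nn
from torch import autograd

def batch_nll(synt, target):
    # Flatten (batch, seq_len, vocab) -> (batch * seq_len, vocab) and
    # (batch, seq_len) -> (batch * seq_len); size_average=False sums the
    # per-token losses, matching the double loop above.
    loss_fn = nn.NLLLoss(size_average=False)
    vocab = synt.size(-1)
    return loss_fn(synt.contiguous().view(-1, vocab), target.contiguous().view(-1))

# toy check with random log-probabilities
batch, seq_len, vocab = 4, 20, 1000
log_probs = nn.functional.log_softmax(autograd.Variable(torch.randn(batch, seq_len, vocab)), dim=-1)
targets = autograd.Variable(torch.LongTensor(batch, seq_len).random_(0, vocab))
print(batch_nll(log_probs, targets))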

The current problem is that the loss value is really small and the network doesn't seem to learn anything (the output is the same word repeated over and over). Any thoughts on this? Thanks in advance!

0 Answers:

There are no answers yet.