具有LSTM的Pytorch批处理训练降噪自动编码器

时间:2020-02-10 21:20:28

标签: pytorch lstm autoencoder

我有这段代码用于训练去噪自动编码器,该编码器将LSTM用于编码器和解码器并按名称操作

def denoise_train(x: DataLoader):
    loss = 0.
    noisy_x = list(map(lambda s: noise_name(s), x))

    rnn_x = to_rnn_tensor(x, DECODER_COUNT)
    rnn_noisy_x = to_rnn_tensor(noisy_x, ENCODER_COUNT)

    encoder_hidden = encoder.init_hidden(batch_size=BATCH_SZ)

    for i in range(rnn_noisy_x.shape[0]):
        _, encoder_hidden = encoder(rnn_noisy_x[i].unsqueeze(0), encoder_hidden)

    decoder_input = strings_to_tensor([SOS] * BATCH_SZ)

    decoder_hidden = encoder_hidden

    name = ''

    for i in range(rnn_x.shape[0]):

        decoder_probs, decoder_hidden = decoder(decoder_input, decoder_hidden)

        _, nonzero_indexes = rnn_x[i].topk(1)

        # TODO!!! Need to fix rest of code for batch

        best_index = torch.argmax(decoder_probs, dim=2).item()

        loss += criterion(decoder_probs[0], nonzero_indexes[0])

        name += ALL_CHARS[best_index]

        decoder_input = torch.zeros(1, 1, LETTERS_COUNT)

        decoder_input[0, 0, best_index] = 1.

    loss.backward()
    return name, noisy_x, loss.item()

在函数参数中传递的x是iter(DataLoader)的下一次迭代。我要做的主要事情是获取批处理时出现的所有decoder_prob的argmax,这是大小名称长度x批处理大小x输出长度。因此,我需要将best_index用作批处理中所有条目的argmax,并且coder_input应该是1xbatch大小x输出,其中所有最佳chars =1。如何在解码器_probs张量中的所有批处理中获得argmax?

0 个答案:

没有答案