我有这段代码用于训练去噪自动编码器,该编码器将LSTM用于编码器和解码器并按名称操作
def denoise_train(x: DataLoader):
loss = 0.
noisy_x = list(map(lambda s: noise_name(s), x))
rnn_x = to_rnn_tensor(x, DECODER_COUNT)
rnn_noisy_x = to_rnn_tensor(noisy_x, ENCODER_COUNT)
encoder_hidden = encoder.init_hidden(batch_size=BATCH_SZ)
for i in range(rnn_noisy_x.shape[0]):
_, encoder_hidden = encoder(rnn_noisy_x[i].unsqueeze(0), encoder_hidden)
decoder_input = strings_to_tensor([SOS] * BATCH_SZ)
decoder_hidden = encoder_hidden
name = ''
for i in range(rnn_x.shape[0]):
decoder_probs, decoder_hidden = decoder(decoder_input, decoder_hidden)
_, nonzero_indexes = rnn_x[i].topk(1)
# TODO!!! Need to fix rest of code for batch
best_index = torch.argmax(decoder_probs, dim=2).item()
loss += criterion(decoder_probs[0], nonzero_indexes[0])
name += ALL_CHARS[best_index]
decoder_input = torch.zeros(1, 1, LETTERS_COUNT)
decoder_input[0, 0, best_index] = 1.
loss.backward()
return name, noisy_x, loss.item()
在函数参数中传递的x是iter(DataLoader)的下一次迭代。我要做的主要事情是获取批处理时出现的所有decoder_prob的argmax,这是大小名称长度x批处理大小x输出长度。因此,我需要将best_index用作批处理中所有条目的argmax,并且coder_input应该是1xbatch大小x输出,其中所有最佳chars =1。如何在解码器_probs张量中的所有批处理中获得argmax?