I have been working on image captioning recently, and I came across code that feeds a packed sequence into an LSTM in PyTorch, as described in this link. I am still somewhat confused. pack_padded_sequence returns an object with two fields: the first is a tensor in which every word of the mini-batch is a row and the embedding dimensions are the columns. No question there. But the second field is just a set of lengths I do not understand. The sum of that second tensor (all of its elements added up) equals the total number of words in the mini-batch, yet the values themselves are different: they are not the lengths of the individual captions, and a single entry seems to cover more than one caption. What exactly are these lengths, what do they represent, and how does PyTorch choose them?

Furthermore, after feeding this object into the LSTM, we get back a result of the same form, except that the number of features is now simply the hidden size. Again, the LSTM output contains two fields, the data tensor and the same puzzling lengths. In that case, how does the LSTM know the sequence length of each caption?
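To make concrete what I mean by the two fields, here is a tiny standalone example I put together (not from the captioning code, just a toy batch) that produces the same kind of object:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Toy batch: 3 padded sequences with embedding size 2 and true lengths 4, 2, 1
# (the lengths have to be given in decreasing order here).
padded = torch.zeros(3, 4, 2)   # (batch, max_seq_len, embed_size)
lengths = [4, 2, 1]

packed = pack_padded_sequence(padded, lengths, batch_first=True)
print(packed.data.shape)      # torch.Size([7, 2])   -> 4 + 2 + 1 = 7 words in total
print(packed.batch_sizes)     # tensor([3, 2, 1, 1]) -> the second field I am asking about

The actual decoder and the test loop from the image-captioning code are below: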
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, lengths):
        # Embed the captions and prepend the image feature as the first time step.
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        # Pack the padded batch so the LSTM skips the padding positions.
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)
        return hiddens
# Run a single mini-batch through the encoder/decoder and inspect the packed output
# (data_loader, encoder, decoder and device are defined earlier in the script).
for i, (images, captions, lengths) in enumerate(data_loader):
    if i == 1:
        break
    images = images.to(device)
    captions = captions.to(device)
    # Packed ground-truth words: keep only the data tensor of the PackedSequence.
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
    features = encoder(images)
    outputs = decoder(features, captions, lengths)
    print(outputs[0].shape)    # packed data tensor
    print(outputs[1])          # the second field, the puzzling "lengths"
    print(outputs[1].sum())
    print(len(outputs[1]))
    print(lengths)
The output is:
torch.Size([1717, 512])
tensor([ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
121, 104, 78, 46, 34, 20, 13, 9, 6, 2,
2, 1, 1])
tensor(1717)
23
[23, 21, 19, 19, 19, 19, 18, 18, 18, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 16, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10, 10, 10, 10, 10]
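For completeness, here is how I would try to invert the packing with pad_packed_sequence (a sketch on top of the `outputs` object above, not something from the tutorial); presumably the per-caption lengths have to be recoverable somewhere for this to work:

from torch.nn.utils.rnn import pad_packed_sequence

# outputs is the PackedSequence returned by decoder(features, captions, lengths) above.
# pad_packed_sequence should undo the packing: a (batch, max_len, hidden_size) tensor
# plus one length per caption.
unpacked, unpacked_lengths = pad_packed_sequence(outputs, batch_first=True)
print(unpacked.shape)        # expected: (batch_size, longest caption length, hidden_size)
print(unpacked_lengths)      # expected: one length per caption, matching `lengths`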