How does an LSTM process packed sequence data in PyTorch?

Time: 2018-08-03 04:46:36

Tags: lstm pytorch

I have recently been working on image captioning. I came across code that feeds a packed sequence into an LSTM in PyTorch, as described in this Link, but I am still somewhat confused.

pack_padded_sequence returns an object with two fields. The first is a tensor in which every word of the mini-batch is a row and the embedding dimensions are the columns; that part is clear. But the second field is just a list of lengths that I don't understand. The sum of all elements of this second tensor equals the total number of words in the mini-batch, yet the individual values don't match up: they are not the lengths of the individual captions, and each entry seems to span more than one caption. What exactly are these lengths? What do they represent, and how does PyTorch choose them?

Furthermore, after feeding the packed sequence into the LSTM, we get the same kind of object back, except that the number of features is now the hidden size. Again, the LSTM's output has two fields: the data tensor and the same puzzling lengths. How does the LSTM know the sequence length of each caption in this case?
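For reference, here is a minimal sketch (not from the original post; it only assumes torch is installed) of what pack_padded_sequence returns on a tiny hand-built batch:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Three padded sequences of lengths 3, 2, and 1 (batch_first=True),
# each timestep a 2-dim "embedding"; zeros are padding.
padded = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]],
                       [[4., 4.], [5., 5.], [0., 0.]],
                       [[6., 6.], [0., 0.], [0., 0.]]])
lengths = [3, 2, 1]

packed = pack_padded_sequence(padded, lengths, batch_first=True)
print(packed.data.shape)   # torch.Size([6, 2]) -> 6 real (non-pad) timesteps, 2 features
print(packed.batch_sizes)  # tensor([3, 2, 1]) -> sequences still active at t=0, 1, 2
```

The second field, batch_sizes, is indexed by timestep, not by sequence: entry t says how many sequences in the batch are still running at step t.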

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class DecoderRNN(nn.Module):

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, lengths):
        embeddings = self.embed(captions)
        # Prepend the image feature vector as the first "word" of each caption
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        # lengths must be sorted in decreasing order (the data loader does this)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)   # hiddens is also a PackedSequence
        return hiddens
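Note that the hiddens returned above is itself a PackedSequence. A sketch (assuming only torch; the sizes here are made up for illustration) of how pad_packed_sequence recovers a padded tensor plus the per-sequence lengths from such an output:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
padded = torch.randn(3, 5, 4)                     # batch of 3, max length 5
packed = pack_padded_sequence(padded, [5, 3, 2], batch_first=True)

out, _ = lstm(packed)                             # out is a PackedSequence
unpacked, out_lengths = pad_packed_sequence(out, batch_first=True)
print(unpacked.shape)   # torch.Size([3, 5, 6]) -> padded again, hidden_size features
print(out_lengths)      # tensor([5, 3, 2]) -> original per-sequence lengths
```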

# Inspect the first mini-batch only
for i, (images, captions, lengths) in enumerate(data_loader):
    if i == 1:
        break
    images = images.to(device)
    captions = captions.to(device)
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
    features = encoder(images)
    outputs = decoder(features, captions, lengths)   # a PackedSequence
    print(outputs[0].shape)    # outputs.data
    print(outputs[1])          # outputs.batch_sizes
    print(outputs[1].sum())
    print(len(outputs[1]))
    print(lengths)

The output is:

torch.Size([1717, 512])
tensor([ 128,  128,  128,  128,  128,  128,  128,  128,  128,  128,
         121,  104,   78,   46,   34,   20,   13,    9,    6,    2,
           2,    1,    1])
tensor(1717)
23
[23, 21, 19, 19, 19, 19, 18, 18, 18, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 16, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10, 10, 10, 10, 10]
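To make sense of these numbers: the second field is batch_sizes, and batch_sizes[t] is the number of captions in the mini-batch that are still running at timestep t, i.e. the count of lengths greater than t. A sketch (assuming only torch; the lengths list is copied from the printout above, grouped by value) that reproduces the printed tensor:

```python
import torch

# The 128 per-caption lengths printed above, sorted in decreasing order
lengths = [23, 21, 19, 19, 19, 19, 18, 18, 18, 17, 17, 17, 17] \
    + [16] * 7 + [15] * 14 + [14] * 12 + [13] * 32 \
    + [12] * 26 + [11] * 17 + [10] * 7

# batch_sizes[t] = how many captions are longer than t timesteps
batch_sizes = torch.tensor([sum(l > t for l in lengths)
                            for t in range(max(lengths))])
print(batch_sizes)        # tensor([128, 128, ..., 121, 104, ..., 2, 2, 1, 1])
print(batch_sizes.sum())  # tensor(1717) -> total number of words in the batch
print(len(batch_sizes))   # 23 -> length of the longest caption
```

This is why the LSTM does not need the per-caption lengths separately: at each timestep it simply processes the first batch_sizes[t] rows, so shorter captions drop out of the computation exactly when they end.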


0 Answers:

There are no answers yet