在pytorch中,我们可以给出一个打包序列作为RNN的输入。从official doc开始,RNN的输入可以如下所示。
输入(seq_len,batch,input_size):包含输入序列功能的张量。输入也可以是打包的可变长度序列。
实施例
packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
outputs, hidden = self.rnn(packed, hidden)
outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs)
此处,embedded
是批量输入的嵌入式表示。
我的问题是,如何对RNN中的打包序列进行计算?如何通过打包表示计算批量填充序列的隐藏状态?
答案 0 :(得分:0)
基于this relevent question的matthew_zeng的答案:未计算填充元素的输出,隐藏将是最后一次有效输入后的隐藏状态。
答案 1 :(得分:0)
对于第二个问题:将不计算填充序列处的隐藏状态。
要回答这种情况是怎么发生的,首先让我们看看 pack_padded_sequence
的作用:
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
raw = [ torch.ones(25, 300) / 2,
torch.ones(22, 300) / 2.3,
torch.ones(15, 300) / 3.2 ]
padded = pad_sequence(raw) # size: [25, 3, 300]
lengths = torch.as_tensor([25, 22, 15], dtype=torch.int64)
packed = pack_padded_sequence(padded, lengths)
到目前为止,我们随机创建了一个具有不同长度的三个张量(在RNN的上下文中为时间步长),我们首先将它们填充到相同的长度,然后打包。现在,如果我们运行
print(padded.size())
print(packed.data.size()) # packed.data refers to the "packed" tensor
我们将看到:
torch.Size([25, 3, 300])
torch.Size([62, 300])
很明显62不是来自25 *3。因此,pack_padded_sequence
所做的只是根据我们传递给lengths
的{{1}}张量来保持每个批处理条目的有意义的时间步长(即如果我们将[25,25,25]传递给它,则pack_padded_sequence
的大小仍将是[75,300],即使原始张量不变。简而言之,rnn甚至不会看到带有pack_padded_sequence的填充时间步长
现在让我们看看将packed.data
和padded
传递给rnn之后的区别
packed
rnn = torch.nn.RNN(input_size=300, hidden_size=2)
padded_outp, padded_hn = rnn(padded) # size: [25, 3, 2] / [1, 3, 2]
packed_outp, packed_hn = rnn(packed) # 'PackedSequence' Obj / [1, 3, 2]
undo_packed_outp, _ = pad_packed_sequence(packed_outp)
# return "h_n"
print(padded_hn) # tensor([[[-0.2329, -0.6179], [-0.1158, -0.5430],[ 0.0998, -0.3768]]])
print(packed_hn) # tensor([[[-0.2329, -0.6179], [ 0.5622, 0.1288], [ 0.5683, 0.1327]]]
# the output of last timestep (the 25-th timestep)
print(padded_outp[-1]) # tensor([[[-0.2329, -0.6179], [-0.1158, -0.5430],[ 0.0998, -0.3768]]])
print(undo_packed_outp.data[-1]) # tensor([[-0.2329, -0.6179], [ 0.0000, 0.0000], [ 0.0000, 0.0000]]
和padded_hn
的值不同,因为rnn不会计算packed_hn
的填充,但不会计算“ packed”(PackedSequence对象)的填充,这也可以从上一个隐藏状态:padded
中的所有三个批处理条目的长度都小于25,即使上一个隐藏状态的长度都小于25,也是如此。但是对于padded
,不计算较短数据的上一个隐藏状态(即, 0)
p.s。另一个观察结果:
packed
将给我们print([(undo_packed_outp[:, i, :].sum(-1) != 0).sum() for i in range(3)])
,该数字与我们输入内容的实际长度一致。