从PyTorch中的BiLSTM(BiGRU)获取最后一个状态

时间:2018-06-14 11:53:01

标签: python lstm pytorch

在阅读了几篇文章后,我仍然对从BiLSTM获取最后隐藏状态的实现的正确性感到困惑。

  1. Understanding Bidirectional RNN in PyTorch (TowardsDataScience)
  2. PackedSequence for seq2seq model (PyTorch forums)
  3. What's the difference between “hidden” and “output” in PyTorch LSTM? (StackOverflow)
  4. Select tensor in a batch of sequences (Pytorch formums)
  5. 来自最后一个源(4)的方法对我来说似乎是最干净的,但我仍然不确定我是否正确理解了该线程。我是否使用LSTM中正确的最终隐藏状态并反转LSTM?这是我的实施

    # pos contains indices of words in embedding matrix
    # seqlengths contains info about sequence lengths
    # so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and 
    # seqlengths contains [3,2], we have batch with samples
    # of variable length [4,6,9] and [3,1]
    
    all_in_embs = self.in_embeddings(pos)
    in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
    output,lasthidden = self.rnn(in_emb_seqs)
    if not self.data_processor.use_gru:
        lasthidden = lasthidden[0]
    # u_emb_batch has shape batch_size x embedding_dimension
    # sum last state from forward and backward  direction
    u_emb_batch = lasthidden[-1,:,:] + lasthidden[-2,:,:]
    

    这是对的吗?

3 个答案:

答案 0 :(得分:3)

在一般情况下,如果要创建自己的BiLSTM网络,则需要创建两个常规LSTM,并使用常规输入序列输入一个,而使用反向输入序列输入另一个。完成两个序列的喂食后,您只需从两个网络中取出最后一个状态,然后以某种方式将它们连接在一起(求和或连接)。

据我了解,您使用的内置BiLSTM与this example中一样(在 nn.LSTM 构造函数中设置bidirectional=True)。然后,在送入批处理后,您将得到连接的输出,因为PyTorch会为您处理所有麻烦。

如果是这种情况,并且您想要对隐藏状态求和,那么您必须

u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

假设您只有一个图层。如果您有更多图层,那么您的变体似乎更好。

这是因为结果是结构化的(见documentation):

  形状的

h_n num_layers * num_directions,batch,hidden_​​size ):包含t = seq_len的隐藏状态的张量

顺便说一下,

u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]

应该提供相同的结果。

答案 1 :(得分:3)

以下是对使用解压缩序列的人员的详细说明:

output的形状为(seq_len, batch, hidden_size * num_directions)(请参见documentation)。这意味着您的GRU向前和向后通过的输出沿第3维并置。

假设您的示例中为batch=2hidden_size=256,则可以通过以下操作轻松地分离正向和反向传递的输出:

output = output.view(-1, 2, 256, 2)   # (seq_len, batch_size, hidden_size, num_directions)
output_forward = output[:, :, :, 0]   # (seq_len, batch_size, hidden_size)
output_backward = output[:, :, :, 1]  # (seq_len, batch_size, hidden_size)

(注意:-1告诉pytorch从其他维度推断出该尺寸。请参见this问题。)

等效地,您可以使用torch.chunk函数

# Split in 2 tensors along dimension 3 (num_directions)
output_forward, output_backward = torch.chunk(output, 2, 3)

torch.chunk保留输入的维数,因此前向和后向张量的形状均为(seq_len, batch_size, hidden_size, 1)。要获得与第一种方法torch.squeeze完全相同的结果,请执行以下操作:

output_forward = output_forward.squeeze(3)   # (seq_len, batch_size, hidden_size)
output_backward = output_backward.squeeze(3) # (seq_len, batch_size, hidden_size)

现在,您可以使用seqlengths {重塑形状之后,torch.gather来{{3}}前进的最后一个隐藏状态,通过选择位置{{1}处的元素,可以{{3}}前进的最后一个隐藏状态。 }

0

请注意,由于基于0的索引,我从# First we unsqueeze seqlengths two times so it has the same number of # of dimensions as output_forward # (batch_size) -> (1, batch_size, 1) lengths = seqlengths.unsqueeze(0).unsqueeze(2) # Then we expand it accordingly # (1, batch_size, 1) -> (1, batch_size, hidden_size) lengths = lengths.expand((1, -1, output_forward.size(2))) last_forward = torch.gather(output_forward, 0, lengths - 1).squeeze(0) last_backward = output_backward[0, :, :] 中减去了1

此时lengthslast_forward的形状均为last_backward

答案 2 :(得分:1)

我测试了 biLSTM 输出和 h_n:

# shape of x is size(batch_size, time_steps, input_size)
# shape of output (batch_size, time_steps, hidden_size * num_directions)
# shape of h_n is size(num_directions, batch_size, hidden_size)
output, (h_n, _c_n) = biLSTM(x) 

print('step 0 of output from reverse == h_n from reverse?', 
    output[:, 0, hidden_size:] == h_n[1])
print('step -1 of output from reverse == h_n from reverse?', 
    output[:, -1, hidden_size:] == h_n[1])

输出

step 0 of output from reverse == h_n from reverse? True
step -1 of output from reverse == h_n from reverse? False

这证实了反向的h_n是第一个时间步的隐藏状态。

所以,如果你真的需要从正向和反向两个方向的最后一个时间步的隐藏状态,你应该使用:

sum_lasthidden = output[:, -1, :hidden_size] + output[:, -1, hidden_size:]

不是

h_n[0,:,:] + h_n[1,:,:]

As h_n[1,:,:] 是反方向第一个时间步的隐藏状态。

@igrinis 的回答

u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

不正确。

但理论上,反向的最后一个时间步隐藏状态只包含序列最后一个时间步的信息。