batch_first会影响Pytorch LSTM中的隐藏张量吗?

时间:2018-04-25 00:15:16

标签: lstm pytorch

batch_first会影响Pytorch LSTM中隐藏的张量吗?

即如果batch_first参数为真, 隐藏状态是(numlayer*direction,num_batch,encoding_dim)(num_batch,numlayer*direction,encoding_dim)

我已经测试了两者,都没有给出任何错误。

3 个答案:

答案 0 :(得分:1)

前段时间我正在考虑同样的问题。像laydog概述的那样,在文档中说

  

batch_first - 如果为True,则提供输入和输出张量   as(batch,seq,feature)

据我所知,我们正在讨论隐藏/单元状态元组的问题,而不是实际的输入和输出。

对我来说,这似乎很明显,这并不影响他们提到的隐藏状态:

  

(批次,序列,功能)

这明确指的是输入和输出,而不是由两个形状元组组成的状态元组:

  

(num_layers * num_directions,batch,hidden_​​size)

所以我很确定隐藏的和单元格状态不受此影响,对我来说改变顺序隐藏状态元组也没有意义。

希望这有帮助。

答案 1 :(得分:0)

来自docs

  

batch_first - 如果为True,则输入和输出张量提供为(batch,seq,feature)

所以,是的,如果您的输入是批量优先,那么输出也将是批量优先。

答案 2 :(得分:0)

让我们看一个例子

batch_size = 8
sequence_length = 10
input_dim = 64
lstm = nn.LSTM(input_size=input_dim, hidden_size=32, num_layers=1, batch_first=True, bidirectional=False)
lstm_input = torch.randn(batch_size, sequence_length, input_dim)
output, (hidden, cell) = lstm(lstm_input)
# >>> output.shape, hidden.shape, cell.shape
# (torch.Size([8, 10, 32]), torch.Size([1, 8, 32]), torch.Size([1, 8, 32]))

因此我们可以看到output变量是批处理优先的,而隐藏状态和单元格的状态不是。

当您想要以初始隐藏或单元格状态进食时,事情会变得更加复杂

您可能期望可以正常工作:

initial_cell = torch.randn(batch_size, 1, 32)
initial_hidden = torch.randn(batch_size, 1, 32)
# WARNING, THIS DOES NOT WORK, IT IS AN EXAMPLE!
# >>> output, (hidden, cell) = lstm(lstm_input, (initial_hidden, initial_cell))
# RuntimeError: Expected hidden[0] size (1, 8, 32), got [8, 1, 32]

这意味着输入的隐藏状态和单元格状态必须为(sequence length, batch, hidden dim)格式。

initial_cell = torch.randn(1, batch_size, 32)
initial_hidden = torch.randn(1, batch_size, 32)
output, (hidden, cell) = lstm(lstm_input, (initial_hidden, initial_cell))
# >>> output.shape, hidden.shape, cell.shape
# (torch.Size([8, 10, 32]), torch.Size([1, 8, 32]), torch.Size([1, 8, 32]))

因此,我们可以看到,无论batch_first,隐藏状态和单元格状态始终是(seq, batch, dim)格式,无论它是LSTM单元格的输入还是输出参数。

GRU的隐藏状态也是如此。