PyTorch LSTM输入尺寸

时间:2019-06-08 12:35:12

标签: python machine-learning pytorch lstm

我正在尝试使用PyTorch LSTM训练一个简单的2层神经网络,而在解释Py​​Torch文档时遇到了麻烦。具体来说,我不太确定如何处理训练数据的形状。

我想做的是通过迷你批处理在非常大的数据集上训练我的网络,其中每批都说100个元素长。每个数据元素将具有5个功能。文档指出,该图层的输入应为形状(seq_len,batch_size,input_size)。我应该如何调整输入的形状?

我一直在关注这篇文章:https://discuss.pytorch.org/t/understanding-lstm-input/31110/3 如果我正确地解释了这一点,则每个小批量的形状应为(100,100,5)。但是在这种情况下,seq_len和batch_size有什么区别?另外,这是否意味着输入LSTM层的第一层应具有5个单位?

谢谢!

1 个答案:

答案 0 :(得分:1)

这是一个古老的问题,但是由于已被查看80多次而没有任何响应,因此让我对此进行解释。

LSTM网络用于预测序列。在NLP中,这将是一个单词序列;在经济学中,一系列经济指标;等

第一个参数是这些序列的长度。如果序列数据是由句子组成的,那么“汤姆的猫又黑又丑”是一个长度为7(seq_len)的序列,每个单词一个,或者可能是第8个序列,表示句子的结尾。

当然,您可能会反对“如果我的序列长度不同怎么办?”这是常见的情况。

两个最常见的解决方案是:

  1. 使用空元素填充序列。例如,如果最长的句子有15个单词,则将上面的句子编码为“ [Tom] [has] [a] [black] [and] [ugly] [cat] [EOS] [] [] [] [] [] [] []”,其中EOS代表句子结尾。突然,您所有的序列长度都变为15,这解决了您的问题。一旦找到[EOS]令牌,该模型就会迅速得知,它后面是无限制的空令牌序列[],这种方法几乎不会给您的网络增加负担。

  2. 发送相同长度的迷你批。例如,在所有句子上使用2个单词训练网络,然后使用3个单词,然后使用4个单词。当然,每个小批量的seq_len都会增加,并且每个小批量的大小将根据长度为N的序列数而变化您的数据中就有。

最好的方法是将数据分成大小大致相等的小批量,按近似长度将它们分组,并仅添加必要的填充。例如,如果您将长度为6、7和8的句子最小化在一起,则长度为8的序列将不需要填充,而长度为6的序列将仅需要填充2。如果您的大型数据集具有长度变化很大的序列,那是最好的方法。

但是,方法1是最简单(也是最懒惰的)方法,并且在小型数据集上效果很好。

最后一件事...始终在数据末尾而不是在开头填充数据。

我希望有帮助。