在Keras中批量处理可变长度的序列

时间:2018-10-28 08:18:11

标签: python tensorflow keras lstm rnn

假设我有一些形状为(?, 10)的数据。

也就是说,数据的特征是长度可变的序列,并且序列中的每个元素都由10个数字表示。

我们希望将其提供给LSTM,因此我们将每个批次的长度准备为(32, m, 10),其中m是批次中示例的最大序列长度。

序列长度小于m的批次中的示例用零填充。

现在,我们要将其提供给LSTM,并希望LSTM停止更新填充输入上的输出。

在Tensorflow中,这将通过其dynamic_rnn的参数sequence_length完成。

我如何在Keras中获得相同的效果?

1 个答案:

答案 0 :(得分:1)

您需要使用Masking,它会产生一个 mask ,该掩码允许LSTM跳过那些填充的值。从文档中:

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(None, 10)))
model.add(LSTM(32))

上面的LSTM现在将跳过所有10个功能都填充为0的时间步。 注意::如果您要返回序列,则将返回先前的隐藏状态:

x = [2, 0, 1, 0] # for example
# will produce
y = [h1, h1, h2, h2] # so 0 pads are skipped
# but returning sequences repeats the last hidden state for masked values