Keras LSTM TimeDistributed,有状态

时间:2017-10-23 19:11:09

标签: keras lstm stateful

是否有详细解释TimeDistributed,stateful和return_sequences如何工作?在这两种情况下我都要设置shuffle = False吗?它适用于窗户(1-11,2-12,3-13等)还是应该分批使用(1-11,12-22,13-33等)

我对LSTM图层特别感兴趣。

1 个答案:

答案 0 :(得分:12)

TimeDistributed:

这不会影响图层的工作方式。 这样做的目的是增加一个"时间" (可能也不是时间)维度。考虑到该时间维度,将包裹的层应用于输入张量的每个切片。

例如,如果一个图层需要一个包含3个维度的输入形状,比如(batch, length, features),那么使用TimeDistributed包装将使其预期有4个维度:(batch, timeDimension, length, features)

然后将该层复制"并且同样适用于时间维度中的每个元素。

使用LSTM图层,它的工作原理相同。虽然LSTM图层已经预期其输入形状的时间维度为(batch, timeSteps, features),但您可以使用TimeDistributed添加另一个" time"维度(可能意味着任何事情,而不是确切时间),并使此LSTM图层可用于此新时间维度中的每个元素。

  • LSTM - 期望输入(batch, timeSteps, features)
  • TimeDistributed(LSTM()) - 期望输入(batch, superSteps, timeSteps, features)

在任何情况下,LSTM只会在timeSteps维度中实际执行其重复计算。另一个时间维度只是多次复制该图层。

TimeDistributed + Dense:

Dense图层(可能还有其他一些图层)已经支持3D输入,但标准是2D:(batch, inputFeatures)

使用TimeDistributed或不使用Dense图层是可选的,结果是相同的:如果您的数据是3D,则Dense图层将重复第二维。

返回序列:

documentation中已对此进行了详细解释。

对于循环图层,keras将使用timeSteps维度来执行其重复执行的步骤。对于每个步骤,它自然会有一个输出。

您可以选择获取所有步骤的输出(return_sequences=True)或仅获取最后一个输出(return_sequences=False

考虑像(batch, timeSteps, inputFeatures)这样的输入形状和具有outputFeatures单位的图层:

  • return_sequences=True时,输出形状为(batch, timeSteps, outputFeatures)
  • return_sequences=False时,输出形状为(batch, outputFeatures)

在任何情况下,如果您使用TimeDistributed包装器,superSteps维度将在输入和输出中保持不变。

Stateful = True

通常,如果您可以将所有序列的所有步骤放在输入数组中,那么一切都很好,您不需要stateful=True层。

Keras创造了一个"状态"对于批次中的每个序列。批量维度等于序列数。当keras完成处理批处理时,它会自动重置状态,这意味着:我们到达序列的结束(最后一步),从第一步开始带来新的序列。

使用stateful=True时,不会重置这些状态。这意味着将另一批次发送到模型不会被解释为一组新的序列,而是之前处理的序列的附加步骤。然后,您必须手动model.reset_states()告诉模型您已到达序列的最后一步,或者您将开始新的序列。

唯一需要shuffle=False的案例是stateful=True案例。因为对于每个批次,输入许多序列。在每批中,这些序列必须保持相同的顺序,以便每个序列的状态不会混合。

有状态层适用于:

  • 太大的数据。如果您一次使用所有时间步骤
  • ,它就不适合您的记忆
  • 您希望连续生成时间步骤并将这些新步骤作为输入添加到下一步,而不需要固定大小。 (您自己在代码中创建这些循环)
  • (其他用户的任何评论??)

使用Windows

到目前为止,我使用Windows的唯一方法是复制数据。

输入数组应该在windows中组织。每个窗口步骤一个序列。如果要将所有窗口步骤保留为单个批处理条目,则可以选择利用TimeDistributed包装器。但是你也可以将所有步骤都作为单独的序列。

由于状态,stateful=True图层无法使用Windows。如果您批量输入1到12的步骤,则下一批将期望步骤13作为保持连接的第一步。