如何建模一个1D CNN + LSTM网络,其中每个时间步长是一组1D阵列?

时间:2019-07-03 00:08:53

标签: keras lstm

我正在尝试使用一组一维CNN和LSTM层对基于Keras的网络进行建模。网络上大多数可用的示例都使用(1, 30, 50)这样的形状的数据(1个样本包含30个时间步长,每个时间步长都有50个特征)。

但是,数据集中的每个时间步都是由许多一维数组组成的。 10个时间步的样本将是(1, 10, 100, 384)(1个批次-一个样本,10个时间步,每个时间步包含100个具有384个特征的数组)。那么,如何定义具有这种形状的模型?

我确实可以将每个时间步长数据(100*384)弄平,但这似乎是不够的,因为可能会使所有的CNN处理无效...另外,每个时间步长数据实际上都是一维的: it不是空间数据

我已经定义了一个简单的模型,如下所示,但是我认为它错误地使用了输入形状的batch_size。我认为它尝试从“ 482个样本”中学习,而不是从具有“ 482个时间步长”的单个样本中学习...

data_input_shape = (482, 100, 384)

model = Sequential()
model.add(Conv1D(300, 1, activation="relu", input_shape=(100,384)))
model.add(MaxPooling1D(4))
model.add(Conv1D(256, 1, activation="relu"))
model.add(MaxPooling1D(4))
model.add(Conv1D(128, 1, activation="relu"))
model.add(MaxPooling1D(5))

model.add(LSTM(200, return_sequences=True))
model.add(LSTM(200, return_sequences=True))
model.add(LSTM(200, return_sequences=True))
model.add(Dense(1, activation='sigmoid'))

有什么建议吗?

1 个答案:

答案 0 :(得分:0)

让我们假设以下两种情况,因为您已经提到100个数组在空间上不相关:

  1. 每个要素的384个值在空间上无关。
  2. 每个要素的384个值与空间有关。例如,它们是经过FFT或类似操作后在整个频率范围内的值。

在情况1中,您基本上具有100x384独立功能。因此,扁平化似乎是一个不错的选择。

尽管如此,在情况2中,对要素应用2D卷积可能是有意义的。方法如下:

首先,您应该以正确的格式准备数据。假设您的数据有482个时间步长,则应确定每个样本中希望有多少个时间步长。例如,您可以决定每个样本中有10个时间步长,而这些样本之间没有重叠,则将为您提供约48个样本。因此,数据现在将具有形状(48、10、100、384)。另外,我们应该添加一个额外的尺寸作为通道,以便能够在Keras中应用2D卷积。这样您的数据就可以变成(48、10、100、384、1)

接下来,您可以决定体系结构。我们将在每个时间步骤将Conv2D应用于每个数组。由于您的数组在空间上不相关,因此我们使用的内核大小为(1,x)或(100,x)。这是一个示例架构:

model = Sequential()
model.add(TimeDistributed(Conv2D(16, (1, 5), activation="relu"), input_shape=(10, 100, 384, 1)))
model.add(TimeDistributed(MaxPooling2D((1, 2))))
model.add(TimeDistributed(Conv2D(32, (100, 9), activation="relu"), input_shape=(10, 100, 384, 1)))
model.add(TimeDistributed(MaxPooling2D((1, 4))))
model.add(TimeDistributed(Flatten()))
model.add(LSTM(16, return_sequences=True))
model.add(Dense(1, activation='sigmoid'))

一些附加说明:

  • 您当然可以为每种类型添加更多层。
  • TimeDistruted是上面的新内容。您可以here来了解它。
  • 如果要开始使用图像,请考虑从一开始就使用CNN / LSTM hybrid或Conv3D,而不是从图像中提取100个数组。
  • 看看有关CNN和LSTM组合层的ConvLSTM2D here