I am building a neural network with Keras and I am a bit confused about the input shape of the LSTM layers. Below is a picture of the relevant part.
The two towers are similar, the only difference being that the left one accepts sequences of any length while the right one only accepts sequences of length 5. As a result, their LSTM layers receive sequences of unspecified length and sequences of length 4 respectively, both with 8 features per timestep. I would therefore expect both LSTM layers to have an input_shape of (1, 8).
What confuses me now is that both LSTM layers accept any input shape without complaint, which is why I suspect this does not work the way I think it does. I would have expected the right LSTM layer to require an input shape whose first dimension is 1, 2 or 4, since only those sizes divide the input sequence of length 4, and I would have expected both layers to require the second dimension to always be 8.
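For example, the following small sketch (my own illustration, assuming standalone Keras 2.x with the TensorFlow backend; the sizes are arbitrary) shows the behaviour that surprises me: the same LSTM configuration builds on 3D inputs of completely different sizes, even though input_shape=(1, 8) is passed.
from keras.layers import Input, LSTM
from keras import backend as K
# Arbitrary (timesteps, features) pairs, none of which match input_shape=(1, 8)
for timesteps, features in [(1, 8), (4, 8), (7, 13)]:
    x = Input(shape=(timesteps, features))
    h = LSTM(units=6, input_shape=(1, 8))(x)            # builds without any error
    print((timesteps, features), '->', K.int_shape(h))  # always (None, 6)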
Can someone explain why the LSTM layers accept any input shape, and whether they handle the sequences correctly with input_shape=(1, 8)? Below is the relevant code.
# Tower 1 (accepts sequences of any length)
inp_sentence1 = Input(shape=(None, 300, 1))
conv11 = Conv2D(32, (2, 300))(inp_sentence1)                         # -> (batch, len-1, 1, 32)
reshape11 = K.squeeze(conv11, 2)                                     # -> (batch, len-1, 32)
maxpl11 = MaxPooling1D(4, data_format='channels_first')(reshape11)   # -> (batch, len-1, 8)
lstm11 = LSTM(units=6, input_shape=(1, 8))(maxpl11)
# Tower 2 (accepts only sequences of length 5)
inp_sentence2 = Input(shape=(5, 300, 1))
conv21 = Conv2D(32, (2, 300))(inp_sentence2)                         # -> (batch, 4, 1, 32)
reshape21 = Reshape((4, 32))(conv21)                                 # -> (batch, 4, 32)
maxpl21 = MaxPooling1D(4, data_format='channels_first')(reshape21)   # -> (batch, 4, 8)
lstm21 = LSTM(units=6, input_shape=(1, 8))(maxpl21)
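To double-check the "8 features per timestep" reasoning, the intermediate shapes can be printed right after the code above (K.int_shape is from the Keras backend; the shapes in the comments are what I expect, not confirmed output):
print(K.int_shape(maxpl11))   # expecting (None, None, 8) for tower 1
print(K.int_shape(maxpl21))   # expecting (None, 4, 8) for tower 2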
EDIT: A short reproduction of the problem with dummy data:
# Imports assumed for the reproduction: standalone Keras 2.x with the TensorFlow backend
import random
import numpy
from keras.models import Model
from keras.layers import Input, Conv2D, Reshape, MaxPooling1D, LSTM, Subtract, Dense
from keras import backend as K
# Tower 1
inp_sentence1 = Input(shape=(None, 300, 1))
conv11 = Conv2D(32, (2, 300))(inp_sentence1)
reshape11 = K.squeeze(conv11, 2)
maxpl11 = MaxPooling1D(4, data_format='channels_first')(reshape11)
lstm11 = LSTM(units=6, input_shape=(1,8))(maxpl11)
# Tower 2
inp_sentence2 = Input(shape=(5, 300, 1))
conv21 = Conv2D(32, (2, 300))(inp_sentence2)
reshape21 = Reshape((4,32))(conv21)
maxpl21 = MaxPooling1D(4, data_format='channels_first')(reshape21)
lstm21 = LSTM(units=6, input_shape=(1,8))(maxpl21)
# Combine towers
substract = Subtract()([lstm11, lstm21])
dense = Dense(16, activation='relu')(substract)
final = Dense(1, activation='sigmoid')(dense)
# Build model
model = Model([inp_sentence1, inp_sentence2], final)
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Create data
random_length = random.randint(2, 10)
x1 = numpy.random.random((100, random_length, 300, 1))   # trailing channel dim to match Input(shape=(None, 300, 1))
x2 = numpy.random.random((100, 5, 300, 1))                # trailing channel dim to match Input(shape=(5, 300, 1))
y = numpy.random.randint(2, size=100)
# Train and predict on data
model.fit([x1, x2], y, epochs=10, batch_size=5)
prediction = model.predict([x1, x2])
prediction = [round(x) for [x] in prediction]
classification = prediction == y
print("accuracy:", sum(classification)/len(prediction))