Passing the output of a CNN to a BiLSTM

Time: 2020-09-08 08:14:06

Tags: python tensorflow keras lstm cnn

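I'm working on a project in which I have to pass the output of a CNN to a bidirectional LSTM. I built the model as follows, but it throws an "incompatible" error. Please let me know where I'm going wrong and how to fix it.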


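    from tensorflow.keras import regularizers
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import (Conv2D, BatchNormalization, MaxPooling2D,
                                         Activation, Dropout, Bidirectional, LSTM, Dense)

    model = Sequential()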
    model.add(Conv2D(filters = 16, kernel_size = 3,input_shape = (32,32,1)))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2,2),strides=1, padding='valid'))
    model.add(Activation('relu'))
    
    model.add(Conv2D(filters = 32, kernel_size=3))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2,2)))
    model.add(Activation('relu'))
    
    model.add(Dropout(0.25))
    model.add(Conv2D(filters = 48, kernel_size=3))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2,2)))
    model.add(Activation('relu'))
    
    model.add(Dropout(0.25))
    model.add(Conv2D(filters = 64, kernel_size=3))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    
    model.add(Dropout(0.25))
    model.add(Conv2D(filters = 80, kernel_size=3))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    
    model.add(Bidirectional(LSTM(150, return_sequences=True)))
    model.add(Dropout(0.3))
    model.add(Bidirectional(LSTM(96)))
    model.add(Dense(total_words // 2, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
    model.add(Dense(total_words, activation='softmax'))
    
    model.summary()

The error returned is:


    ValueError                                Traceback (most recent call last)
    <ipython-input-24-261befed7006> in <module>()
         27 model.add(Activation('relu'))
         28 
    ---> 29 model.add(Bidirectional(LSTM(150, return_sequences=True)))
         30 model.add(Dropout(0.3))
         31 model.add(Bidirectional(LSTM(96)))
    
    5 frames
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name)
        178                          'expected ndim=' + str(spec.ndim) + ', found ndim=' +
        179                          str(ndim) + '. Full shape received: ' +
    --> 180                          str(x.shape.as_list()))
        181     if spec.max_ndim is not None:
        182       ndim = x.shape.ndims
    
    ValueError: Input 0 of layer bidirectional is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, 1, 1, 80]
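The shape in the error can be checked by hand. Assuming Keras's default padding='valid', every Conv2D and MaxPooling2D layer maps a spatial size n to (n - kernel) // stride + 1, and MaxPooling2D's stride defaults to its pool_size. The short sketch below (the helper name valid_out is hypothetical) traces the 32x32 input through the question's stack:

```python
def valid_out(n, kernel, stride=1):
    """Output length of an unpadded ('valid') convolution or pooling window."""
    return (n - kernel) // stride + 1

# (kernel, stride) of every layer that changes the spatial size, in model order.
# BatchNormalization, Activation and Dropout keep the shape, so they are skipped.
spatial_layers = [
    (3, 1),  # Conv2D(16, 3)
    (2, 1),  # MaxPooling2D((2, 2), strides=1)
    (3, 1),  # Conv2D(32, 3)
    (2, 2),  # MaxPooling2D((2, 2)): stride defaults to pool_size
    (3, 1),  # Conv2D(48, 3)
    (2, 2),  # MaxPooling2D((2, 2))
    (3, 1),  # Conv2D(64, 3)
    (3, 1),  # Conv2D(80, 3)
]

size = 32
trace = [size]
for kernel, stride in spatial_layers:
    size = valid_out(size, kernel, stride)
    trace.append(size)

print(trace)  # [32, 30, 29, 27, 13, 11, 5, 3, 1]
```

The last Conv2D therefore emits a 1x1 grid with 80 channels, which is the [None, 1, 1, 80] in the error: four dimensions, where Bidirectional(LSTM(...)) expects three, (batch, timesteps, features).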

2 Answers:

Answer 0 (score: 0)

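Conv2D works on 2-dimensional inputs/outputs (height and width, plus a channel axis), but an LSTM takes a 1-dimensional sequence. That is why the layer expects 3 dimensions (batch, sequence, hidden) but finds 4 (batch, X, Y, hidden). The solution is to collapse the CNN output into a 1-dimensional sequence after the CNN and before the LSTM. Note that a plain Flatten layer is not enough on its own: it merges every non-batch axis into one and returns a 2-D (batch, features) tensor, which the LSTM rejects as well. Merge only the spatial axes instead, e.g. Reshape((-1, 80)) turns (batch, X, Y, 80) into (batch, X*Y, 80).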
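A quick NumPy check of the shapes involved shows the difference between flattening everything and merging only the spatial axes. NumPy stands in for the Keras layers here, and the batch size and feature-map size are hypothetical:

```python
import numpy as np

batch, h, w, c = 4, 5, 5, 80               # hypothetical CNN output (batch, H, W, channels)
features = np.random.rand(batch, h, w, c)

# Flatten-style: every non-batch axis is merged -> 2-D, still not LSTM input
flat = features.reshape(batch, -1)
print(flat.shape, flat.ndim)               # (4, 2000) 2

# Merge only H and W, keep channels -> 3-D (batch, timesteps, features)
seq = features.reshape(batch, h * w, c)
print(seq.shape, seq.ndim)                 # (4, 25, 80) 3

# Time step i*w + j is the channel vector at spatial position (i, j)
assert np.array_equal(seq[:, 1 * w + 2], features[:, 1, 2])
```

Each time step of the 3-D version is the 80-channel vector at one spatial position, visited row by row.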

Answer 1 (score: 0)

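The problem is the data passed to the LSTM, and it can be fixed inside your network. An LSTM needs 3D data of shape (batch_size, timesteps, features). There are two possibilities:
1) reshape to (batch_size, H, W*channel), so that each row of the feature map is one time step;
2) reshape to (batch_size, W, H*channel), so that each column is one time step (H and W must be swapped before reshaping).
Either way, the LSTM receives 3D data.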
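What each possibility means for the time axis can be seen with plain NumPy. This is a sketch on a hypothetical 3x4 feature map; np.transpose plays the role of Keras's Permute((2, 1, 3)), which swaps H and W:

```python
import numpy as np

batch, h, w, c = 2, 3, 4, 80          # hypothetical feature map (batch, H, W, channel)
x = np.random.rand(batch, h, w, c)

# Possibility 1: rows are time steps -> (batch, H, W*channel)
by_rows = x.reshape(batch, h, w * c)

# Possibility 2: columns are time steps -> swap H and W, then merge
by_cols = x.transpose(0, 2, 1, 3).reshape(batch, w, h * c)

print(by_rows.shape)  # (2, 3, 320)
print(by_cols.shape)  # (2, 4, 240)

# Time step t holds all of row t (possibility 1) or all of column t (possibility 2)
assert np.array_equal(by_rows[:, 1], x[:, 1].reshape(batch, -1))
assert np.array_equal(by_cols[:, 2], x[:, :, 2].reshape(batch, -1))
```

With the question's 32x32 input the final feature map is 1x1, so either possibility yields a sequence of length 1 in this particular model.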

Example:

    import tensorflow as tf
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import (Conv2D, BatchNormalization, MaxPooling2D,
                                         Activation, Dropout, Lambda, Bidirectional,
                                         LSTM, Dense)

    def ReshapeLayer(x):
        # x has shape (batch, H, W, channel); H, W and channel are known statically
        _, h, w, c = x.shape

        # 1 possibility: H time steps of W*channel features
        return tf.reshape(x, (-1, h, w * c))

        # 2 possibility: W time steps of H*channel features
        # (swap the H and W axes first, then merge)
        # x = tf.transpose(x, (0, 2, 1, 3))
        # return tf.reshape(x, (-1, w, h * c))

    total_words = 300
    model = Sequential()
    model.add(Conv2D(filters = 16, kernel_size = 3,input_shape = (32,32,1)))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2,2),strides=1, padding='valid'))
    model.add(Activation('relu'))

    model.add(Conv2D(filters = 32, kernel_size=3))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2,2)))
    model.add(Activation('relu'))

    model.add(Dropout(0.25))
    model.add(Conv2D(filters = 48, kernel_size=3))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2,2)))
    model.add(Activation('relu'))

    model.add(Dropout(0.25))
    model.add(Conv2D(filters = 64, kernel_size=3))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    model.add(Dropout(0.25))
    model.add(Conv2D(filters = 80, kernel_size=3))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    model.add(Lambda(ReshapeLayer))  # <============ (None, 1, 1, 80) -> (None, 1, 80)

    model.add(Bidirectional(LSTM(150, return_sequences=True)))
    model.add(Dropout(0.3))
    model.add(Bidirectional(LSTM(96)))
    model.add(Dense(total_words // 2, activation='relu'))
    model.add(Dense(total_words, activation='softmax'))

    model.summary()
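The same reshaping can also be expressed with Keras's built-in Permute and Reshape layers instead of a Lambda; the Keras documentation notes that Lambda layers are saved by serializing Python bytecode, which is not portable. Below is a minimal sketch using the functional API, assuming TensorFlow 2.x and a hypothetical 5x4x80 feature map (chosen so the two possibilities give different sequence lengths):

```python
import tensorflow as tf
from tensorflow.keras import layers

h, w, c = 5, 4, 80  # hypothetical CNN feature-map size (H, W, channel)
features = tf.keras.Input(shape=(h, w, c))

# Possibility 1: rows as time steps -> (batch, H, W*channel)
rows = layers.Reshape((h, w * c))(features)

# Possibility 2: columns as time steps. Permute dims are 1-based and exclude
# the batch axis, so (2, 1, 3) swaps H and W -> (batch, W, H*channel)
cols = layers.Reshape((w, h * c))(layers.Permute((2, 1, 3))(features))

# Either 3-D tensor can now feed the recurrent part of the model
encoded = layers.Bidirectional(layers.LSTM(96))(rows)

print(tuple(rows.shape))     # (None, 5, 320)
print(tuple(cols.shape))     # (None, 4, 400)
print(tuple(encoded.shape))  # (None, 192): forward and backward outputs concatenated
```

In a Sequential model the equivalent would be adding Reshape((h, w * c)) directly after the last convolution block, with h, w and c read off model.summary().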