如何在Keras的CNN上堆叠LSTM?

时间:2018-04-17 03:45:20

标签: tensorflow machine-learning deep-learning keras lstm

我为声音识别目的制作了以下神经网络模型。流程图如下所示:

cnn-lstm-dense-hybrid(please click here)

这个想法如下:

  1. 我有两个不同的输入图层,分别叫做A和B.

    (i)输入A有 100个时间步长,每个步骤都有一个64维特征向量

    (ii)1D CNN层(时间分布)将从每个时间步骤中提取特征。 CNN层包含 64个过滤器,每个过滤器长度为16个抽头。然后, maxpooling 图层将提取每个卷积输出的单个最大值,因此在每个时间步将提取总共64个要素

    (iii)CNN层的输出将被馈入具有64个神经元的 LSTM层。重复次数与输入的时间步长相同,即100个时间步长。 LSTM图层应该返回一个64维输出序列(序列长度==时间步数== 100,所以应该有100 * 64 = 6400个数字。)

    (iv)同时,输入B也有100个时间步长,每个步骤都有65维特征向量,但它们与输入A的处理方式不同。

    (v)输入B 被送入65个神经元的密集层(时间分布),因此它应该在每个时间步产生65维输出

  2. 现在,在每个时间步,我们有来自LSTM层(64个神经元)和密集层(65个神经元)的输出,我们在合并层中连接。现在我们在每个时间步都得到 129维向量

  3. 我们将此矢量输入另一个致密层,产生输出(单个神经元,表示“目标声音”的概率)
  4. A hand drawn illustration

    但是,我一开始就试图让1(i)工作。网络建设的代码如下:

    mfcc_input = Input(shape=(100,64), dtype='float', name='mfcc_input')
    print(mfcc_input)
    
    CNN_out = TimeDistributed(Conv1D(64, 16, activation='relu'))(mfcc_input)
    CNN_out = BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True)(CNN_out)
    CNN_out = TimeDistributed(MaxPooling1D(pool_size=(64-16+1), strides=None, padding='valid'))(CNN_out)
    CNN_out = Dropout(0.4)(CNN_out)
    
    
    LSTM_out = LSTM(64,return_sequences=True)(CNN_out)
    
    ## Auxilliary branch
    delta_input = Input(shape=(100,64), dtype='float', name='delta_input')
    zcr_input   = Input(shape=(100,1), dtype='float', name='zcr_input')
    aux_input   = concatenate([delta_input, zcr_input])
    aux_out     = TimeDistributed(Dense(64+1))(aux_input) 
    
    ### Merge branches
    merged_layer   = concatenate([LSTM_out, aux_out])
    
    ## Output layer
    output = TimeDistributed(Dense(1))(merged_layer)
    
    model = Model(inputs=[mfcc_input, delta_input, zcr_input], outputs=[output])
    
    model.compile(optimizer='rmsprop', loss='binary_crossentropy',
              loss_weights=[1., 0.2])
    ...(other code here) ...
    

    “CNN_out = TimeDistributed(Conv1D(64,16,activation ='relu'))(mfcc_input)”中的错误“是: IndexError:列表索引超出范围

    有人可以帮忙吗?非常感谢!

0 个答案:

没有答案