如何将注意层应用于LSTM模型

时间:2020-11-09 07:53:39

标签: python tensorflow keras attention-model

我正在做语音情感识别机器培训。

我希望对模型应用关注层。指令page很难理解。

def bi_duo_LSTM_model(X_train, y_train, X_test,y_test,num_classes,batch_size=68,units=128, learning_rate=0.005, epochs=20, dropout=0.2, recurrent_dropout=0.2):
    
    class myCallback(tf.keras.callbacks.Callback):

        def on_epoch_end(self, epoch, logs={}):
            if (logs.get('acc') > 0.95):
                print("\nReached 99% accuracy so cancelling training!")
                self.model.stop_training = True

    callbacks = myCallback()

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Masking(mask_value=0.0, input_shape=(X_train.shape[1], X_train.shape[2])))
    model.add(tf.keras.layers.Bidirectional(LSTM(units, dropout=dropout, recurrent_dropout=recurrent_dropout,return_sequences=True)))
    model.add(tf.keras.layers.Bidirectional(LSTM(units, dropout=dropout, recurrent_dropout=recurrent_dropout)))
    #     model.add(tf.keras.layers.Bidirectional(LSTM(32)))
    model.add(Dense(num_classes, activation='softmax'))

    adamopt = tf.keras.optimizers.Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    RMSopt = tf.keras.optimizers.RMSprop(lr=learning_rate, rho=0.9, epsilon=1e-6)
    SGDopt = tf.keras.optimizers.SGD(lr=learning_rate, momentum=0.9, decay=0.1, nesterov=False)

    model.compile(loss='binary_crossentropy',
                  optimizer=adamopt,
                  metrics=['accuracy'])

    history = model.fit(X_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(X_test, y_test),
                        verbose=1,
                        callbacks=[callbacks])

    score, acc = model.evaluate(X_test, y_test,
                                batch_size=batch_size)

    yhat = model.predict(X_test)

    return history, yhat

如何应用它以适合我的模型?

use_scalecausaldropout都是参数吗?

如果dropout中有attention layer,由于LSTM层中有dropout,我们该如何处理?

1 个答案:

答案 0 :(得分:0)

注意力可以解释为软向量检索。

  • 您有一些查询向量。对于每个查询,您要检索一些

  • ,以便您计算它们的权重

  • ,其中权重是通过将查询与进行比较而获得的(键的数量必须与值的数量相同,并且它们通常是相同的向量)。

在序列到序列模型中,查询是解码器状态,而键和值是解码器状态。

在分类任务中,您没有这样的显式查询。解决此问题的最简单方法是训练“通用”查询,该查询用于从隐藏状态中收集相关信息(类似于最初描述的in this paper)。

如果您将问题标记为序列标签,而不是将标签分配给整个序列,而是分配给各个时间步长,则可能需要使用自关注层。