Keras RNN model for NLP takes a very long time to train and the validation loss does not decrease

Asked: 2019-06-21 07:55:36

Tags: keras deep-learning nlp lstm recurrent-neural-network

I have built an RNN model for entity recognition. I use BERT embeddings and then process the result with the RNN model. However, when training the model for 5 epochs, each epoch seems to take about 2 hours. Moreover, the validation loss does not seem to decrease at all.

I am running the process on an RTX 2080 GPU. I have tried tweaking the model, but nothing has improved it. The dataset I have contains about 400,000 sentences.

Here is my model:

# Assumed imports; BertLayer is a custom Keras layer wrapping BERT
from keras.models import Model
from keras.layers import (Input, LSTM, Bidirectional, Dense,
                          TimeDistributed, RepeatVector, add)

lstm_units = 100  # yields the 200-unit bidirectional outputs shown in the summary

def build_model(max_seq_length, n_tags): 
    in_id = Input(shape=(max_seq_length,), name="input_ids")
    in_mask = Input(shape=(max_seq_length,), name="input_masks")
    in_segment = Input(shape=(max_seq_length,), name="segment_ids")

    bert_inputs = [in_id, in_mask, in_segment]   
    bert_output = BertLayer(n_fine_tune_layers=3, pooling="first")(bert_inputs)
    x = RepeatVector(max_seq_length)(bert_output)
    x = Bidirectional(LSTM(units=lstm_units, return_sequences=True,
                           recurrent_dropout=0.2, dropout=0.2))(x)
    x_rnn = Bidirectional(LSTM(units=lstm_units, return_sequences=True,
                               recurrent_dropout=0.2, dropout=0.2))(x)
    x = add([x, x_rnn])  # residual connection to the first biLSTM
    pred = TimeDistributed(Dense(n_tags, activation="softmax"))(x)

    model = Model(inputs=bert_inputs, outputs=pred)
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()
    return model

Here is the model summary:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_ids (InputLayer)          (None, 30)           0                                            
__________________________________________________________________________________________________
input_masks (InputLayer)        (None, 30)           0                                            
__________________________________________________________________________________________________
segment_ids (InputLayer)        (None, 30)           0                                            
__________________________________________________________________________________________________
bert_layer_3 (BertLayer)        ((None, 30), 768)    110104890   input_ids[0][0]                  
                                                                 input_masks[0][0]                
                                                                 segment_ids[0][0]                
__________________________________________________________________________________________________
repeat_vector_2 (RepeatVector)  ((None, 30), 30, 768 0           bert_layer_3[0][0]               
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) ((None, 30), 30, 200 695200      repeat_vector_2[0][0]            
__________________________________________________________________________________________________
bidirectional_3 (Bidirectional) ((None, 30), 30, 200 240800      bidirectional_2[0][0]            
__________________________________________________________________________________________________
add_1 (Add)                     ((None, 30), 30, 200 0           bidirectional_2[0][0]            
                                                                 bidirectional_3[0][0]            
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib ((None, 30), 30, 3)  603         add_1[0][0]                      
==================================================================================================
Total params: 111,041,493
Trainable params: 22,790,811
Non-trainable params: 88,250,682
__________________________________________________________________________________________________

Training log:

 32336/445607 [=>............................] - ETA: 2:12:59 - loss: 0.3469 - acc: 0.9068
 32352/445607 [=>............................] - ETA: 2:12:58 - loss: 0.3469 - acc: 0.9068
 32368/445607 [=>............................] - ETA: 2:12:58 - loss: 0.3469 - acc: 0.9068

Can you help me figure out where I am going wrong?

1 Answer:

Answer 0: (score: 2)

If you are using BERT for embeddings, the output shape should be (None, 30, 768). But your BERT layer returns a (None, 768) tensor, which you then copy with RepeatVector. My guess is that you are extracting the [CLS] output from BERT. Please extract the correct output (the per-token sequence output) from the BERT model.
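The problem with repeating a pooled vector can be illustrated with NumPy alone (a minimal sketch; the shapes match the model summary above):

```python
import numpy as np

# Pooled [CLS]-style output: one 768-dim vector per sentence (batch of 2)
pooled = np.random.rand(2, 768)

# What RepeatVector(30) does: copy the same vector to all 30 timesteps
repeated = np.repeat(pooled[:, None, :], 30, axis=1)

print(repeated.shape)  # (2, 30, 768)
# Every timestep is an identical copy, so the BiLSTM sees no token-level information
print(np.allclose(repeated[:, 0, :], repeated[:, 29, :]))  # True
```

With the true sequence output, each of the 30 timesteps would instead carry a distinct per-token embedding, which is what a token-level tagger needs.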

The reason training takes so long is simply that, in every epoch, all the data has to be passed through the huge BERT model, even though you have frozen most of its layers.
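If you freeze BERT completely (no fine-tuned layers), one common way around this is to run the encoder over the dataset once, cache the features, and train only the lightweight BiLSTM head on them. A minimal NumPy sketch of the caching pattern, where `frozen_encoder` is a hypothetical stand-in for the expensive BERT forward pass:

```python
import numpy as np

def frozen_encoder(batch):
    # Hypothetical stand-in for the frozen BERT forward pass
    # (in practice: bert_model.predict(batch) on the frozen encoder)
    return batch @ np.ones((30, 768)) / 30.0

X = np.random.rand(1000, 30)

# Run the heavy encoder exactly once and save the result
features = frozen_encoder(X)          # shape (1000, 768)
np.save("bert_features.npy", features)

# Each epoch then trains only the small head on the cached features,
# instead of re-running BERT over all 400k sentences every time
for epoch in range(5):
    pass  # fit the BiLSTM head on `features` here

print(features.shape)  # (1000, 768)
```

Note this only works when BERT is entirely frozen; with `n_fine_tune_layers=3` the BERT layers change each step, so their outputs cannot be cached.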