How do I get a BERT model to converge?

Asked: 2020-03-18 00:32:20

Tags: tensorflow keras nlp text-classification bert-language-model

I'm trying to use BERT for sentiment analysis, but I suspect I'm doing something wrong. In the code below I fine-tune BERT with bert-for-tf2, yet after one epoch I only get about 42% accuracy, while a simple GRU model reaches around 73%. What should I be doing to use BERT effectively? One suspicion is that I'm training the BERT layer from the very first batch, which could be a problem because the dense head is still randomly initialized. Any advice would be appreciated, thanks!

import os
import numpy as np
import tensorflow
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import bert  # the bert-for-tf2 package, imported as "bert"
model_name = "uncased_L-12_H-768_A-12"
model_dir = bert.fetch_google_bert_model(model_name, ".models")
model_ckpt = os.path.join(model_dir, "bert_model.ckpt")

bert_params = bert.params_from_pretrained_ckpt(model_dir)
l_bert = bert.BertModelLayer.from_params(bert_params, name="bert")
max_seq_len = 100
l_input_ids = tensorflow.keras.layers.Input(shape=(max_seq_len,), dtype='int32')

bertLayer = l_bert(l_input_ids)
flat = Flatten()(bertLayer)
output = Dense(1,activation = 'sigmoid')(flat)
model = tensorflow.keras.Model(inputs=l_input_ids, outputs=output)
model.build(input_shape=(None, max_seq_len))

bert.load_bert_weights(l_bert, model_ckpt) 

with open('../preprocessing_scripts/new_train_data.txt', 'r') as f:
  tweets = f.readlines()

with open('../preprocessing_scripts/targets.csv', 'r') as f:
  targets = f.readlines()

max_words = 14000
tokenizer = Tokenizer(num_words=max_words)

trainX = tweets[:6000]
trainY = targets[:6000]
testX = tweets[6000:]
testY = targets[6000:]
maxlen = 100
tokenizer.fit_on_texts(trainX)

tokenized_version = tokenizer.texts_to_sequences(trainX)

tokenized_version = pad_sequences(tokenized_version, maxlen=maxlen)
trainY = np.array(trainY, dtype='int32')
model.compile(loss="binary_crossentropy",
              optimizer="adam",
              metrics=['accuracy']) 

history = model.fit(x=tokenized_version, y=trainY, batch_size = 32, epochs=1, validation_split = 0.2)
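
For reference, this is roughly what I mean by not training the BERT layer from the first batch, i.e. freezing it while the dense head warms up. This is just a sketch of the idea, not something I have validated; it assumes the same `l_bert` layer and data objects as above and uses the standard Keras `trainable` flag.

# Sketch: freeze the pretrained BERT weights for a warm-up epoch,
# then unfreeze and fine-tune end to end.
# Re-compile after toggling `trainable` so the change takes effect.
l_bert.trainable = False
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])
model.fit(x=tokenized_version, y=trainY, batch_size=32, epochs=1, validation_split=0.2)

l_bert.trainable = True
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])
model.fit(x=tokenized_version, y=trainY, batch_size=32, epochs=1, validation_split=0.2)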

1 Answer:

Answer 0: (score: 0)

I think your learning rate (the Adam default of 0.001) is too high and is causing catastrophic forgetting. Reference: How to Fine-Tune BERT for Text Classification? https://arxiv.org/pdf/1905.05583.pdf

Ideally the learning rate should be on the order of 1e-5. Try changing your code as follows and it should work:

from keras_radam import RAdam

model.compile(
    RAdam(lr=2e-5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
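
If keras_radam is not installed, the built-in Adam optimizer at the same order of magnitude should be a reasonable substitute. This is my assumption rather than part of the answer above, which only prescribes the magnitude of the learning rate:

# Assumption: plain tf.keras Adam at 2e-5 in place of RAdam.
from tensorflow.keras.optimizers import Adam

model.compile(
    Adam(learning_rate=2e-5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)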