我正在寻找一些技巧来训练具有动态生成的bert嵌入的常规神经网络模型(BERT上下文化的嵌入为同一单词在不同的上下文中生成不同的嵌入)。
在正常的神经网络模型中,我们将使用手套或快速文本嵌入(例如,
import torch.nn as nn
embed = nn.Embedding(vocab_size, vector_size)
embed.weight.data.copy_(some_variable_containing_vectors)
我不想复制这样的静态向量并将其用于训练,我想将每个输入传递给BERT模型,并即时生成单词的嵌入,然后将其输入模型进行训练。
那么我应该在模型中更改前向功能以合并那些嵌入吗?
任何帮助将不胜感激!
答案 0 :(得分:5)
如果您使用的是Pytorch。您可以使用https://github.com/huggingface/pytorch-pretrained-BERT,它是Pytorch最受欢迎的BERT实现(它也是一个pip包!)。在这里,我将概述如何正确使用它。
对于此特定问题,有2种方法-您显然无法使用Embedding
层:
您可以编写一个循环来为这样的字符串生成BERT令牌(假设-因为BERT占用大量GPU内存):
(注意:为了更加适当,您还应该添加注意掩码-它们是1和0的LongTensor,用于掩盖句子的长度)
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel
batch_size = 32
X_train, y_train = samples_from_file('train.csv') # Put your own data loading function here
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
X_train = [tokenizer.tokenize('[CLS] ' + sent + ' [SEP]') for sent in X_train] # Appending [CLS] and [SEP] tokens - this probably can be done in a cleaner way
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model = bert_model.cuda()
X_train_tokens = [tokenizer.convert_tokens_to_ids(sent) for sent in X_train]
results = torch.zeros((len(X_test_tokens), bert_model.config.hidden_size)).long()
with torch.no_grad():
for stidx in range(0, len(X_test_tokens), batch_size):
X = X_test_tokens[stidx:stidx + batch_size]
X = torch.LongTensor(X).cuda()
_, pooled_output = bert_model(X)
results[stidx:stidx + batch_size,:] = pooled_output.cpu()
之后,您将获得包含计算出的嵌入的results
张量,您可以在其中将其用作模型的输入。
此方法的优点是不必每个时期都重新计算这些嵌入。
使用这种方法,例如,为了进行分类,您的模型应仅由Linear(bert_model.config.hidden_size, num_labels)
层组成,模型的输入应为上述代码中的results
张量
BertForSequenceClassification
)。实现自BertPretrainedModel
继承并利用仓库中各种Bert类的自定义类也应该很容易。例如,您可以使用:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', labels=num_labels) # Where num_labels is the number of labels you need to classify.
之后,您可以继续进行预处理,直到生成令牌ID。然后,您可以训练整个模型(但学习率较低,例如batch_size
= 32的Adam 3e-5)
通过这种方法,您可以自己微调BERT的嵌入,或者使用诸如冻结BERT的技术在几个时期内仅训练分类器,然后解冻以进行微调等。但这在计算上也更加昂贵。
中也提供了一个示例。