拥抱的面孔变形者:encode_plus中的截断策略

时间:2020-08-06 09:14:29

标签: pytorch huggingface-transformers

拥抱的转换器库中的

encode_plus允许截断输入序列。两个参数是相关的:truncationmax_length。我正在将成对的输入序列传递到encode_plus,并且需要简单地以“截止”方式截断输入序列,即,如果整个序列都包含两个输入text和{{1} }的长度比text_pair长,应从右侧将其相应地截断。

似乎没有一种截断策略允许这样做,而是max_length从最长序列中删除标记(可以是text或text_pair,而不仅仅是从序列的右边或结尾,例如,如果文本比text_pair长,则似乎会先从文本中删除标记),longest_firstonly_first仅从第一个或第二个标记中删除标记(因此,也不仅仅是从结尾处去除标记), only_second完全不会截断。还是我误解了这一点,实际上do_not_truncate可能正是我想要的?

1 个答案:

答案 0 :(得分:2)

没有longest_firstcut from the right不同。当您将截断策略设置为longest_first时,每次需要删除令牌时,令牌生成器都会比较texttext_pair的长度,并从最长的令牌中删除令牌。例如,这可能意味着它将从text_pair开始削减3个令牌,并将从texttext_pair交替切割的其余令牌削减。一个例子:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

seq1 = 'This is a long uninteresting text'
seq2 = 'What could be a second sequence to the uninteresting text'

print(len(tokenizer.tokenize(seq1)))
print(len(tokenizer.tokenize(seq2)))

print(tokenizer(seq1, seq2))

print(tokenizer(seq1, seq2, truncation= True, max_length = 15))
print(tokenizer.decode(tokenizer(seq1, seq2, truncation= True, max_length = 15)['input_ids']))

输出:

9
13
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] this is a long unint [SEP] what could be a second sequence [SEP]

据您所知,您实际上正在寻找only_second,因为它是从右边(text_pair)开始切入的

print(tokenizer(seq1, seq2, truncation= 'only_second', max_length = 15))

输出:

{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

当您尝试将text输入的长度设置为指定的max_length时,将引发异常。我认为这是正确的,因为在这种情况下,它不再是序列对输入。

只要only_second不符合您的要求,您只需创建自己的截断策略即可。例如,手动only_second


tok_seq1 = tokenizer.tokenize(seq1)
tok_seq2 = tokenizer.tokenize(seq2)

maxLengthSeq2 =  myMax_len - len(tok_seq1) - 3 #number of special tokens for bert sequence pair
if len(tok_seq2) >  maxLengthSeq2:
    tok_seq2 = tok_seq2[:maxLengthSeq2]

input_ids = [tokenizer.cls_token_id] 
input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
input_ids += [tokenizer.sep_token_id]

token_type_ids = [0]*len(input_ids)

input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
input_ids += [tokenizer.sep_token_id]
token_type_ids += [1]*(len(tok_seq2)+1) 


attention_mask = [1]*len(input_ids)
print(input_ids)
print(token_type_ids)
print(attention_mask)

输出:

[101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]