encode_plus
允许截断输入序列。两个参数是相关的:truncation
和max_length
。我正在将成对的输入序列传递到encode_plus
,并且需要简单地以“截止”方式截断输入序列,即,如果整个序列都包含两个输入text
和{{1} }的长度比text_pair
长,应从右侧将其相应地截断。
似乎没有一种截断策略允许这样做,而是max_length
从最长序列中删除标记(可以是text或text_pair,而不仅仅是从序列的右边或结尾,例如,如果文本比text_pair长,则似乎会先从文本中删除标记),longest_first
和only_first
仅从第一个或第二个标记中删除标记(因此,也不仅仅是从结尾处去除标记), only_second
完全不会截断。还是我误解了这一点,实际上do_not_truncate
可能正是我想要的?
答案 0 :(得分:2)
没有longest_first
与cut from the right
不同。当您将截断策略设置为longest_first
时,每次需要删除令牌时,令牌生成器都会比较text
和text_pair
的长度,并从最长的令牌中删除令牌。例如,这可能意味着它将从text_pair
开始削减3个令牌,并将从text
和text_pair
交替切割的其余令牌削减。一个例子:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
seq1 = 'This is a long uninteresting text'
seq2 = 'What could be a second sequence to the uninteresting text'
print(len(tokenizer.tokenize(seq1)))
print(len(tokenizer.tokenize(seq2)))
print(tokenizer(seq1, seq2))
print(tokenizer(seq1, seq2, truncation= True, max_length = 15))
print(tokenizer.decode(tokenizer(seq1, seq2, truncation= True, max_length = 15)['input_ids']))
输出:
9
13
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] this is a long unint [SEP] what could be a second sequence [SEP]
据您所知,您实际上正在寻找only_second
,因为它是从右边(text_pair
)开始切入的
print(tokenizer(seq1, seq2, truncation= 'only_second', max_length = 15))
输出:
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
当您尝试将text
输入的长度设置为指定的max_length时,将引发异常。我认为这是正确的,因为在这种情况下,它不再是序列对输入。
只要only_second
不符合您的要求,您只需创建自己的截断策略即可。例如,手动only_second
:
tok_seq1 = tokenizer.tokenize(seq1)
tok_seq2 = tokenizer.tokenize(seq2)
maxLengthSeq2 = myMax_len - len(tok_seq1) - 3 #number of special tokens for bert sequence pair
if len(tok_seq2) > maxLengthSeq2:
tok_seq2 = tok_seq2[:maxLengthSeq2]
input_ids = [tokenizer.cls_token_id]
input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
input_ids += [tokenizer.sep_token_id]
token_type_ids = [0]*len(input_ids)
input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
input_ids += [tokenizer.sep_token_id]
token_type_ids += [1]*(len(tok_seq2)+1)
attention_mask = [1]*len(input_ids)
print(input_ids)
print(token_type_ids)
print(attention_mask)
输出:
[101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]