我有一本字典(有10k +个单词)和一段(有10M +个单词)。我想用<unk>
替换字典中没有出现的所有单词。
我尝试了str.maketrans
,但是它的密钥应该是一个字符。
然后我尝试了此https://stackoverflow.com/a/40348578/5634636,但是正则表达式非常慢。
有更好的解决方案吗?
答案 0 :(得分:2)
我们将问题分为两部分:
passage
,找到索引i,其中passage[i]
不在另一个单词列表dictionary
中。<unk>
放在那些索引上。 1中需要进行主要工作。为此,我们首先将字符串列表转换为2D numpy数组,以便我们可以高效地执行操作。另外,我们对二进制搜索中下面需要的字典进行排序。另外,我们用0填充字典,以使其与passage_enc
的列数相同。
# assume passage, dictionary are initially lists of words
passage = np.array(passage) # np array of dtype='<U4'
passage_enc = passage.view(np.uint8).reshape(-1, passage.itemsize)[:, ::4] # 2D np array of size len(passage) x max(len(x) for x in passage), with ords of chars
dictionary = np.array(dictionary)
dictionary = np.sort(dictionary)
dictionary_enc = dictionary.view(np.uint8).reshape(-1, dictionary.itemsize)[:, ::4]
pad = np.zeros((len(dictionary), passage_enc.shape[1] - dictionary_enc.shape[1]))
dictionary_enc = np.hstack([dictionary_enc, pad]).astype(np.uint8)
然后,我们仅遍历段落,并检查字符串(现在是数组)是否在字典中。它将取O(n * m),n,m分别是段落和字典的大小。 但是,我们可以通过预先对字典进行排序并在其中进行二进制搜索来改进此功能。因此,它变为O(n * logm)。
此外,我们JIT编译代码以使其更快。在下面,我使用numba。
import numba as nb
import numpy as np
@nb.njit(cache=True) # cache as being used multiple times
def smaller(a, b):
n = len(a)
i = 0
while(i<n and a[i] == b[i]):
i+=1
if(i==n):
return False
return a[i] < b[i]
@nb.njit(cache=True)
def bin_index(array, item):
first, last = 0, len(array) - 1
while first <= last:
mid = (first + last) // 2
if np.all(array[mid] == item):
return mid
if smaller(item, array[mid]):
last = mid - 1
else:
first = mid + 1
return -1
@nb.njit(cache=True)
def replace(dictionary, passage):
unknown_indices = []
n = len(passage)
for i in range(n):
ind = bin_index(dictionary, passage[i])
if(ind == -1):
unknown_indices.append(i)
return unknown_indices
检查样本数据
import nltk
emma = nltk.corpus.gutenberg.words('austen-emma.txt')
passage = np.array(emma)
passage = np.repeat(passage, 50) # bloat coprus to have around 10mil words
passage_enc = passage.view(np.uint8).reshape(-1, passage.itemsize)[:, ::4]
persuasion = nltk.corpus.gutenberg.words('austen-persuasion.txt')
dictionary = np.array(persuasion)
dictionary = np.sort(dictionary) # sort for binary search
dictionary_enc = dictionary.view(np.uint8).reshape(-1, dictionary.itemsize)[:, ::4]
pad = np.zeros((len(dictionary), passage_enc.shape[1] - dictionary_enc.shape[1]))
dictionary_enc = np.hstack([dictionary_enc, pad]).astype(np.uint8) # pad with zeros so as to make dictionary_enc and passage_enc of same shape[1]
段落和字典的大小,最终出于计时目的而达到OP要求的顺序。这个电话:
unknown_indices = replace(dictionary_enc, passage_enc)
在我的8核,16 G系统上花费了17.028s(包括预处理时间,显然不包括加载语料库的时间)。
然后,这很简单:
passage[unknown_indices] = "<unk>"
P.S:我想,通过在parallel=True
的njit装饰器中使用replace
,我们可以提高速度。我遇到了一些奇怪的错误,如果我能够对其进行解决,则会进行编辑。