这是一篇自我回答的帖子。下面我概述NLP域中的一个常见问题,并提出一些高效的方法来解决它。
通常需要在文本清理和预处理期间删除标点符号。标点符号定义为string.punctuation
中的任何字符:
>>> import string
string.punctuation
'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
这是一个很常见的问题,并且在恶心之前就被问过了。最惯用的解决方案是使用pandas str.replace
。但是,对于涉及 lot 文本的情况,可能需要考虑更高性能的解决方案。
在处理数十万条记录时,str.replace
有哪些优秀,高效的替代方案?
答案 0 :(得分:47)
为了演示,让我们考虑一下这个DataFrame。
df = pd.DataFrame({'text':['a..b?!??', '%hgh&12','abc123!!!', '$$$1234']})
df
text
0 a..b?!??
1 %hgh&12
2 abc123!!!
3 $$$1234
下面,我逐一列出替代方案,按性能递增顺序
str.replace
此选项用于将默认方法建立为比较其他更高性能解决方案的基准。
这使用pandas内置str.replace
函数执行基于正则表达式的替换。
df['text'] = df['text'].str.replace(r'[^\w\s]+', '')
df
text
0 ab
1 hgh12
2 abc123
3 1234
这很容易编码,并且可读性很强,但速度很慢。
regex.sub
这涉及使用sub
库中的re
函数。为性能预编译正则表达式模式,并在列表解析中调用regex.sub
。如果你可以节省一些内存,事先将df['text']
转换为一个列表,你可以从中获得很好的性能提升。
import re
p = re.compile(r'[^\w\s]+')
df['text'] = [p.sub('', x) for x in df['text'].tolist()]
df
text
0 ab
1 hgh12
2 abc123
3 1234
注意:如果您的数据具有NaN值,则此(以及下面的下一个方法)将无法正常工作。请参阅“其他注意事项”部分。
str.translate
python的str.translate
函数在C中实现,因此非常快。
这是如何工作的:
str.translate
,删除标点符号(不包括第1步中的分隔符)。在这个例子中,我们考虑管道分隔符|
。如果您的数据包含管道,则必须选择另一个分隔符。
import string
punct = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{}~' # `|` is not present here
transtab = str.maketrans(dict.fromkeys(punct, ''))
df['text'] = '|'.join(df['text'].tolist()).translate(transtab).split('|')
df
text
0 ab
1 hgh12
2 abc123
3 1234
str.translate
表现最好。请注意,下图包含来自MaxU's answer的其他变体Series.str.translate
。
(有趣的是,我第二次重申这一点,结果与之前略有不同。在第二次运行期间,似乎re.sub
在str.translate
上赢得了非常少量的数据。 )
使用translate
存在固有的风险(尤其是自动化决定使用哪个分隔符的过程非常重要),但权衡取舍值得冒风险。
使用列表推导方法处理NaN; 请注意,只有您的数据没有NaN时,此方法(以及下一个方法)才有效。处理NaN时,您必须确定非空值的索引并仅替换它们。尝试这样的事情:
df = pd.DataFrame({'text': [
'a..b?!??', np.nan, '%hgh&12','abc123!!!', '$$$1234', np.nan]})
idx = np.flatnonzero(df['text'].notna())
col_idx = df.columns.get_loc('text')
df.iloc[idx,col_idx] = [
p.sub('', x) for x in df.iloc[idx,col_idx].tolist()]
df
text
0 ab
1 NaN
2 hgh12
3 abc123
4 1234
5 NaN
处理DataFrames; 如果您正在处理DataFrames,每个列都需要替换,那么过程很简单:
v = pd.Series(df.values.ravel())
df[:] = translate(v).values.reshape(df.shape)
或者,
v = df.stack()
v[:] = translate(v)
df = v.unstack()
请注意,下面使用基准测试代码定义translate
函数。
每个解决方案都有权衡,因此决定最适合您需求的解决方案将取决于您愿意牺牲的内容。两个非常常见的考虑因素是性能(我们已经看到)和内存使用情况。 str.translate
是一个需要大量内存的解决方案,因此请谨慎使用。
另一个考虑因素是你的正则表达式的复杂性。有时,您可能希望删除任何非字母数字或空格的内容。在Othertimes中,您需要保留某些字符,例如连字符,冒号和句子终止符[.!?]
。指定这些明确增加了正则表达式的复杂性,这可能反过来影响这些解决方案的性能。确保测试这些解决方案
在决定使用什么之前对你的数据进行处理。
最后,使用此解决方案将删除unicode字符。您可能想要调整正则表达式(如果使用基于正则表达式的解决方案),或者只是使用str.translate
。
对于更多表现(对于较大的N),请按Paul Panzer查看此答案。
<强>功能强>
def pd_replace(df):
return df.assign(text=df['text'].str.replace(r'[^\w\s]+', ''))
def re_sub(df):
p = re.compile(r'[^\w\s]+')
return df.assign(text=[p.sub('', x) for x in df['text'].tolist()])
def translate(df):
punct = string.punctuation.replace('|', '')
transtab = str.maketrans(dict.fromkeys(punct, ''))
return df.assign(
text='|'.join(df['text'].tolist()).translate(transtab).split('|')
)
# MaxU's version (https://stackoverflow.com/a/50444659/4909087)
def pd_translate(df):
punct = string.punctuation.replace('|', '')
transtab = str.maketrans(dict.fromkeys(punct, ''))
return df.assign(text=df['text'].str.translate(transtab))
效果基准代码
from timeit import timeit
import pandas as pd
import matplotlib.pyplot as plt
res = pd.DataFrame(
index=['pd_replace', 're_sub', 'translate', 'pd_translate'],
columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000],
dtype=float
)
for f in res.index:
for c in res.columns:
l = ['a..b?!??', '%hgh&12','abc123!!!', '$$$1234'] * c
df = pd.DataFrame({'text' : l})
stmt = '{}(df)'.format(f)
setp = 'from __main__ import df, {}'.format(f)
res.at[f, c] = timeit(stmt, setp, number=30)
ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");
plt.show()
答案 1 :(得分:26)
使用numpy,我们可以获得迄今为止发布的最佳方法的健康加速。基本策略是相似的 - 制作一个超级大字符串。但是在numpy中处理似乎要快得多,大概是因为我们充分利用了无需替换的操作的简单性。
对于较小的(总共少于0x110000
个字符)问题,我们会自动找到分隔符,对于较大的问题,我们使用较慢的方法,而不依赖于str.split
。
请注意,我已将所有预可计算项移出函数。另请注意,translate
和pd_translate
了解免费的三大问题唯一可能的分隔符,而np_multi_strat
必须计算它或回退到无分隔符策略。最后,请注意,对于最后三个数据点,我切换到了一个更有趣的&#34;问题; pd_replace
和re_sub
,因为它们不等同于其他方法必须被排除在外。
基本策略实际上非常简单。只有0x110000
个不同的unicode字符。由于OP构成了巨大数据集的挑战,因此制作一个查询表非常值得我们想要保留的字符ID True
和False
必须去---我们的例子中的标点符号。
使用numpy的高级索引可以将这样的查找表用于批量连接。由于查找是完全向量化的,并且基本上相当于取消引用指针数组,因此它比例如字典查找要快得多。这里我们使用numpy视图转换,它允许将unicode字符重新解释为整数,基本上是免费的。
使用仅包含一个怪物字符串的数据数组,该字符串被重新解释为要编入查找表的数字序列,从而产生布尔掩码。然后可以使用此掩码过滤掉不需要的字符。使用布尔索引也是一行代码。
到目前为止这么简单。棘手的一点就是把怪物绳子砍回原来的部位。如果我们有一个分隔符,即一个字符在数据或标点符号列表中没有出现,那么它仍然很容易。使用此字符加入并重新分页。但是,自动查找分隔符很有挑战性,实际上占下面实现中loc的一半。
或者,我们可以将分割点保持在单独的数据结构中,跟踪它们如何移动以删除不需要的字符,然后使用它们来切片处理过的怪物字符串。由于切割成不均匀长度的部分不是numpy最强的套装,这种方法比str.split
慢,并且仅当分隔符太昂贵而无法计算它是否存在于第一位时用作后备
代码(根据@ COLDSPEED&#39;帖子计时/绘图很多):
import numpy as np
import pandas as pd
import string
import re
spct = np.array([string.punctuation]).view(np.int32)
lookup = np.zeros((0x110000,), dtype=bool)
lookup[spct] = True
invlookup = ~lookup
OSEP = spct[0]
SEP = chr(OSEP)
while SEP in string.punctuation:
OSEP = np.random.randint(0, 0x110000)
SEP = chr(OSEP)
def find_sep_2(letters):
letters = np.array([letters]).view(np.int32)
msk = invlookup.copy()
msk[letters] = False
sep = msk.argmax()
if not msk[sep]:
return None
return sep
def find_sep(letters, sep=0x88000):
letters = np.array([letters]).view(np.int32)
cmp = np.sign(sep-letters)
cmpf = np.sign(sep-spct)
if cmp.sum() + cmpf.sum() >= 1:
left, right, gs = sep+1, 0x110000, -1
else:
left, right, gs = 0, sep, 1
idx, = np.where(cmp == gs)
idxf, = np.where(cmpf == gs)
sep = (left + right) // 2
while True:
cmp = np.sign(sep-letters[idx])
cmpf = np.sign(sep-spct[idxf])
if cmp.all() and cmpf.all():
return sep
if cmp.sum() + cmpf.sum() >= (left & 1 == right & 1):
left, sep, gs = sep+1, (right + sep) // 2, -1
else:
right, sep, gs = sep, (left + sep) // 2, 1
idx = idx[cmp == gs]
idxf = idxf[cmpf == gs]
def np_multi_strat(df):
L = df['text'].tolist()
all_ = ''.join(L)
sep = 0x088000
if chr(sep) in all_: # very unlikely ...
if len(all_) >= 0x110000: # fall back to separator-less method
# (finding separator too expensive)
LL = np.array((0, *map(len, L)))
LLL = LL.cumsum()
all_ = np.array([all_]).view(np.int32)
pnct = invlookup[all_]
NL = np.add.reduceat(pnct, LLL[:-1])
NLL = np.concatenate([[0], NL.cumsum()]).tolist()
all_ = all_[pnct]
all_ = all_.view(f'U{all_.size}').item(0)
return df.assign(text=[all_[NLL[i]:NLL[i+1]]
for i in range(len(NLL)-1)])
elif len(all_) >= 0x22000: # use mask
sep = find_sep_2(all_)
else: # use bisection
sep = find_sep(all_)
all_ = np.array([chr(sep).join(L)]).view(np.int32)
pnct = invlookup[all_]
all_ = all_[pnct]
all_ = all_.view(f'U{all_.size}').item(0)
return df.assign(text=all_.split(chr(sep)))
def pd_replace(df):
return df.assign(text=df['text'].str.replace(r'[^\w\s]+', ''))
p = re.compile(r'[^\w\s]+')
def re_sub(df):
return df.assign(text=[p.sub('', x) for x in df['text'].tolist()])
punct = string.punctuation.replace(SEP, '')
transtab = str.maketrans(dict.fromkeys(punct, ''))
def translate(df):
return df.assign(
text=SEP.join(df['text'].tolist()).translate(transtab).split(SEP)
)
# MaxU's version (https://stackoverflow.com/a/50444659/4909087)
def pd_translate(df):
return df.assign(text=df['text'].str.translate(transtab))
from timeit import timeit
import pandas as pd
import matplotlib.pyplot as plt
res = pd.DataFrame(
index=['translate', 'pd_replace', 're_sub', 'pd_translate', 'np_multi_strat'],
columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000,
1000000],
dtype=float
)
for c in res.columns:
if c >= 100000: # stress test the separator finder
all_ = np.r_[:OSEP, OSEP+1:0x110000].repeat(c//10000)
np.random.shuffle(all_)
split = np.arange(c-1) + \
np.sort(np.random.randint(0, len(all_) - c + 2, (c-1,)))
l = [x.view(f'U{x.size}').item(0) for x in np.split(all_, split)]
else:
l = ['a..b?!??', '%hgh&12','abc123!!!', '$$$1234'] * c
df = pd.DataFrame({'text' : l})
for f in res.index:
if f == res.index[0]:
ref = globals()[f](df).text
elif not (ref == globals()[f](df).text).all():
res.at[f, c] = np.nan
print(f, 'disagrees at', c)
continue
stmt = '{}(df)'.format(f)
setp = 'from __main__ import df, {}'.format(f)
res.at[f, c] = timeit(stmt, setp, number=16)
ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");
plt.show()
答案 2 :(得分:18)
有趣的是,与Vanilla Python str.translate()
相比,矢量化Series.str.translate方法仍然稍慢:
def pd_translate(df):
return df.assign(text=df['text'].str.translate(transtab))