Pyspark NLP-CountVectorizer最大DF或TF。如何从数据集中过滤常见事件

时间:2018-07-02 21:41:00

标签: python apache-spark pyspark nlp apache-spark-ml

我正在使用CountVectorizer为ML准备数据集。我想过滤掉稀有单词,为此使用了CountVectorizer,minDF或minTF参数。我还想删除在数据集中“经常”出现的项目。我看不到可以设置的maxTF或maxDF参数。有什么好方法吗?

df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])

因此,在这种情况下,如果我想删除出现“ 4”次或40%的时间以及出现两次或更少的参数的参数。这将删除“ b”和“ c”。

当前,我为下限要求运行CountVectorizer(minDf=3......)。如何过滤出比我想建模的对象更频繁出现的项目。

1 个答案:

答案 0 :(得分:0)

我想您要求提供CountVectorizer参数,但直到现在似乎没有该参数。这不是简单实现的简单或实用方法,但它可以工作。希望对您有所帮助:

from pyspark.sql.types import *
from pyspark.sql import functions as F

df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])

counts_df = df \
    .select(F.explode('raw').alias('testCol')) \
    .groupby('testCol') \
    .agg(F.count('testCol').alias('count')).persist() # this will be used multiple times

total = counts_df \
    .agg(F.sum('count').alias('total')) \
    .rdd.take(1)[0]['total']
min_times = 3
max_times = total * 0.4
filtered_elements = counts_df \
    .filter((min_times>F.col('count')) | (F.col('count')>max_times)) \
    .select('testCol') \
    .rdd.map(lambda row: row['testCol']) \
    .collect()

def removeElements(arr):
    return list(set(arr) - set(filtered_elements))

remove_udf = F.udf(removeElements, ArrayType(StringType()))
filtered_df = df \
    .withColumn('raw', remove_udf('raw'))

结果:

filtered_df.show()
+-----+---+
|label|raw|
+-----+---+
|    0|[a]|
|    1|[a]|
+-----+---+