我正在使用CountVectorizer
为ML准备数据集。我想过滤掉稀有单词,为此使用了CountVectorizer
,minDF或minTF参数。我还想删除在数据集中“经常”出现的项目。我看不到可以设置的maxTF或maxDF参数。有什么好方法吗?
df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])
因此,在这种情况下,如果我想删除出现“ 4”次或40%的时间以及出现两次或更少的参数的参数。这将删除“ b”和“ c”。
当前,我为下限要求运行CountVectorizer(minDf=3......)
。如何过滤出比我想建模的对象更频繁出现的项目。
答案 0 :(得分:0)
我想您要求提供CountVectorizer参数,但直到现在似乎没有该参数。这不是简单实现的简单或实用方法,但它可以工作。希望对您有所帮助:
from pyspark.sql.types import *
from pyspark.sql import functions as F
df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])
counts_df = df \
.select(F.explode('raw').alias('testCol')) \
.groupby('testCol') \
.agg(F.count('testCol').alias('count')).persist() # this will be used multiple times
total = counts_df \
.agg(F.sum('count').alias('total')) \
.rdd.take(1)[0]['total']
min_times = 3
max_times = total * 0.4
filtered_elements = counts_df \
.filter((min_times>F.col('count')) | (F.col('count')>max_times)) \
.select('testCol') \
.rdd.map(lambda row: row['testCol']) \
.collect()
def removeElements(arr):
return list(set(arr) - set(filtered_elements))
remove_udf = F.udf(removeElements, ArrayType(StringType()))
filtered_df = df \
.withColumn('raw', remove_udf('raw'))
结果:
filtered_df.show()
+-----+---+
|label|raw|
+-----+---+
| 0|[a]|
| 1|[a]|
+-----+---+