How to efficiently group low-frequency levels of a high-cardinality categorical column in pyspark?

Asked: 2018-02-01 16:04:13

Tags: python string apache-spark pyspark spark-dataframe

I am currently trying to find an efficient way to group the levels of a categorical StringType() column that occur infrequently. I want to do this based on a percentage threshold, i.e. replace all values that appear in fewer than z% of the rows. It is also important that we can map back and forth between the numerical values (after applying StringIndexer) and the original values.

Basically, with a threshold of 25%, this dataframe:

+---+---+---+---+
| x1| x2| x3| x4|
+---+---+---+---+
|  a|  a|  a|  a|
|  b|  b|  a|  b|
|  a|  a|  a|  c|
|  b|  b|  a|  d|
|  c|  a|  a|  e|
+---+---+---+---+

should become this one:

+------+------+------+------+
|x1_new|x2_new|x3_new|x4_new|
+------+------+------+------+
|     a|     a|     a| other|
|     b|     b|     a| other|
|     a|     a|     a| other|
|     b|     b|     a| other|
| other|     a|     a| other|
+------+------+------+------+

where c has been replaced by other in column x1, and all values in column x4 have been replaced by other, because they each occur in fewer than 25% of the rows.

My hope was to use a regular StringIndexer and exploit the fact that the values are ordered by their frequency. We could then compute how many values to keep and replace all other values with e.g. -1. The problem with this approach: it throws an error after IndexToString, I assume because the metadata is lost.

My question: is there a good way to do this? Am I perhaps overlooking built-in functionality? And is there a way to preserve the metadata?

Thanks in advance!

import pandas as pd

# sample data; sqlContext is assumed to be an existing SQLContext
df = pd.DataFrame({'x1' : ['a','b','a','b','c'],  # a: 0.4, b: 0.4, c: 0.2
                   'x2' : ['a','b','a','b','a'],  # a: 0.6, b: 0.4, c: 0.0
                   'x3' : ['a','a','a','a','a'],  # a: 1.0, b: 0.0, c: 0.0
                   'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)
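
For illustration, here is a minimal sketch of the StringIndexer-based idea described above, applied only to x4 (the ix_x4/x4_back names and the -1 sentinel are just placeholders). It reproduces the problem: remapping the indices with a plain column expression drops the ML attribute metadata, so IndexToString has no labels left to map back to.

from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.sql.functions import col, when, lit

threshold = 0.25
total = df.count()

# Index one column; StringIndexer assigns the lowest indices to the most frequent values.
indexed = StringIndexer(inputCol='x4', outputCol='ix_x4').fit(df).transform(df)

# Number of distinct values that occur in at least `threshold` of the rows.
n_to_keep = (indexed.groupBy('x4').count()
                    .filter(col('count') / total >= threshold)
                    .count())

# Remap every infrequent index to a sentinel value (-1).
# This when/otherwise expression drops the column's ML attribute metadata.
indexed = indexed.withColumn(
    'ix_x4',
    when(col('ix_x4') >= n_to_keep, lit(-1.0)).otherwise(col('ix_x4')))

# This now fails, as described above: the column carries no labels in its metadata
# and none are passed explicitly.
# IndexToString(inputCol='ix_x4', outputCol='x4_back').transform(indexed).show()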

1 Answer:

Answer 0 (score: 2)

I did some further digging and stumbled upon this post about adding metadata to a column in pyspark. Based on that I was able to create a function called group_low_freq that I think is quite efficient; it applies StringIndexer only once and then modifies the column and its metadata so that all values occurring in less than x% of the rows are put into a separate group called "other". Because we also modify the metadata, we can still retrieve the original strings later with IndexToString. The function and an example are shown below:

Code:

import findspark
findspark.init()
import pyspark as ps
from pyspark.sql import SQLContext, Column
import pandas as pd
import numpy as np
from pyspark.sql.functions import col, count as sparkcount, when, lit
from pyspark.sql.types import StringType
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml import Pipeline
import json 

try:
    sc
except NameError:
    sc = ps.SparkContext()
    sqlContext = SQLContext(sc)

from pyspark.sql.functions import col

def withMeta(self, alias, meta):
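    """Return `self` aliased to `alias`, with the dict `meta` attached as column metadata
    via the JVM Metadata API (adapted from the linked post)."""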
    sc = ps.SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

def group_low_freq(df,inColumns,threshold=.01,group_text='other'):
    """
    Index string columns and group, per column, all values that occur in less than a fraction `threshold` of the rows of df.
    :param df: A pyspark.sql.dataframe.DataFrame
    :param inColumns: String columns that need to be indexed
    :param threshold: Fraction of rows below which a value is grouped into `group_text`
    :param group_text: String to use as replacement for the values that need to be grouped
    """
    total = df.count()
    for string_col in inColumns:
        # Apply string indexer
        pipeline = Pipeline(stages=[StringIndexer(inputCol=string_col, outputCol="ix_"+string_col)])
        df = pipeline.fit(df).transform(df)

        # Calculate the number of unique elements to keep
        n_to_keep = df.groupby(string_col).agg((sparkcount(string_col)/total).alias('perc')).filter(col('perc')>threshold).count()

        # Grab the StringIndexer metadata; if some values need grouping, trim the label list
        # to the kept values plus `group_text` and remap all indices >= n_to_keep to n_to_keep.
        this_meta = df.select('ix_' + string_col).schema.fields[0].metadata
        if n_to_keep != len(this_meta['ml_attr']['vals']):  
            this_meta['ml_attr']['vals'] = this_meta['ml_attr']['vals'][0:(n_to_keep+1)]
            this_meta['ml_attr']['vals'][n_to_keep] = group_text    
            df = df.withColumn('ix_'+string_col,when(col('ix_'+string_col)>=n_to_keep,lit(n_to_keep)).otherwise(col('ix_'+string_col)))

        # Re-attach the adjusted metadata to the index column (withColumn replaces it in place).
        df = df.withColumn('ix_'+string_col, withMeta(col('ix_'+string_col), "", this_meta))

    return df




# SAMPLE DATA -----------------------------------------------------------------

df = pd.DataFrame({'x1' : ['a','b','a','b','c'],  # a: 0.4, b: 0.4, c: 0.2
                   'x2' : ['a','b','a','b','a'],  # a: 0.6, b: 0.4, c: 0.0
                   'x3' : ['a','a','a','a','a'],  # a: 1.0, b: 0.0, c: 0.0
                   'x4' : ['a','b','c','d','e']}) # a: 0.2, b: 0.2, c: 0.2, d: 0.2, e: 0.2
df = sqlContext.createDataFrame(df)

# TEST THE FUNCTION -----------------------------------------------------------

df = group_low_freq(df,df.columns,0.25)    

ix_cols = [x for x in df.columns if 'ix_' in x]
for string_col in ix_cols:    
    idx_to_string = IndexToString(inputCol=string_col, outputCol=string_col[3:]+'grouped')
    df = idx_to_string.transform(df)

df.show()

Output with a threshold of 25% (so each group has to occur in at least 25% of the rows):

    +---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
    | x1| x2| x3| x4|ix_x1|ix_x2|ix_x3|ix_x4|x1grouped|x2grouped|x3grouped|x4grouped|
    +---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
    |  a|  a|  a|  a|  0.0|  0.0|  0.0|  0.0|        a|        a|        a|    other|
    |  b|  b|  a|  b|  1.0|  1.0|  0.0|  0.0|        b|        b|        a|    other|
    |  a|  a|  a|  c|  0.0|  0.0|  0.0|  0.0|        a|        a|        a|    other|
    |  b|  b|  a|  d|  1.0|  1.0|  0.0|  0.0|        b|        b|        a|    other|
    |  c|  a|  a|  e|  2.0|  0.0|  0.0|  0.0|    other|        a|        a|    other|
    +---+---+---+---+-----+-----+-----+-----+---------+---------+---------+---------+
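
To check that the labels were indeed rewritten, you can inspect the metadata of one of the index columns; for example, with the 25% threshold every level of x4 is grouped, so only the replacement label should remain:

print(df.select('ix_x4').schema.fields[0].metadata['ml_attr']['vals'])
# ['other']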