How to count categorical variables by group in PySpark?

Time: 2018-08-25 18:36:57

Tags: python apache-spark pyspark pyspark-sql

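I can run the code below and get the output shown, but it does not work if the same AlertType occurs more than once for a SessionID. In that case, I need a way to get values other than 1.0 in the OHE column of the output. The error is related to an iterator.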

I got some help from the following question and its answers: How to add sparse vectors after group by, using Spark SQL?

columns=['SessionID','AlertType']
vals=[
    (1,0),
    (1,1),
    (1,2),
    (1,3),
    (1,4),
    (2,0),
    (2,1),
    (2,2),
    (2,3),
    (2,4),
]

df=spark.createDataFrame(vals,columns)
df.show()

+---------+---------+
|SessionID|AlertType|
+---------+---------+
|        1|        0|
|        1|        1|
|        1|        2|
|        1|        3|
|        1|        4|
|        2|        0|
|        2|        1|
|        2|        2|
|        2|        3|
|        2|        4|
+---------+---------+

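from pyspark.sql.functions import collect_list, max, lit, udf  # note: this max shadows the built-in max
from pyspark.ml.linalg import Vectors, VectorUDT

def encode(arr, length):
    # Build a sparse vector of the given length with 1.0 at every index in arr.
    # Repeated indices in arr are not combined into a count.
    vec_args = length, [(x, 1.0) for x in arr]
    return Vectors.sparse(*vec_args)

# Wrap encode as a UDF that returns an ML vector column
encode_udf = udf(encode, VectorUDT())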
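
One possible direction for the repeated-AlertType case, offered as a sketch rather than a verified fix: count how often each index occurs before building the vector, so that every index appears only once and carries its count as the value. The helper name `count_pairs` and the sample input are hypothetical. The function uses only the standard library so the counting logic can be checked without a Spark session.

```python
from collections import Counter


def count_pairs(arr):
    """Collapse repeated indices into sorted (index, count) pairs."""
    # StringIndexer emits float indices such as 2.0, so cast them to int first
    counts = Counter(int(x) for x in arr)
    # Sort by index so each index appears exactly once, in increasing order
    return [(idx, float(n)) for idx, n in sorted(counts.items())]


# Hypothetical collected indices for one session in which alert 1 fires three times
pairs = count_pairs([0.0, 1.0, 1.0, 4.0, 1.0])
print(pairs)  # [(0, 1.0), (1, 3.0), (4, 1.0)]
```

Inside `encode`, the list comprehension could then be replaced by `count_pairs(arr)`, on the assumption that `Vectors.sparse` takes a size followed by a list of (index, value) pairs, as it does in the code above.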

# Map each AlertType to a numeric index with StringIndexer
from pyspark.ml.feature import StringIndexer
indexer=StringIndexer(inputCol='AlertType',outputCol='AlertTypeStrIndexed').fit(df)
df_strIndexed=indexer.transform(df)
df_strIndexed.show()

+---------+---------+-------------------+
|SessionID|AlertType|AlertTypeStrIndexed|
+---------+---------+-------------------+
|        1|        0|                2.0|
|        1|        1|                1.0|
|        1|        2|                3.0|
|        1|        3|                4.0|
|        1|        4|                0.0|
|        2|        0|                2.0|
|        2|        1|                1.0|
|        2|        2|                3.0|
|        2|        3|                4.0|
|        2|        4|                0.0|
+---------+---------+-------------------+

df_strIndexed.agg(max(df_strIndexed["AlertTypeStrIndexed"])).show()
feats = df_strIndexed.agg(max(df_strIndexed["AlertTypeStrIndexed"])).take(1)[0][0] + 1

# Keep the DataFrame itself; show() returns None, so call it separately
df_OHE_grouped = (
    df_strIndexed.groupBy("SessionID")
    .agg(collect_list("AlertTypeStrIndexed").alias("AlertArray"))
    .select("SessionID", encode_udf("AlertArray", lit(feats)).alias("OHE"))
)
df_OHE_grouped.show(truncate=False)

+---------+-------------------------------------+
|SessionID|OHE                                  |
+---------+-------------------------------------+
|1        |(5,[0,1,2,3,4],[1.0,1.0,1.0,1.0,1.0])|
|2        |(5,[0,1,2,3,4],[1.0,1.0,1.0,1.0,1.0])|
+---------+-------------------------------------+

0 answers:

No answers