我可以运行以下代码并获取包含的输出,但是如果同一AlertType对于SessionID多次出现,它将不起作用。在这种情况下,我需要一种方法来在输出的OHE列中获取非1.0的值。该错误与迭代器有关。
我从以下问题和解答中获得了一些帮助:How to add sparse vectors after group by, using Spark SQL?
# Sample data: one (SessionID, AlertType) row per alert occurrence.
columns = ['SessionID', 'AlertType']
vals = [(session, alert) for session in (1, 2) for alert in range(5)]
df = spark.createDataFrame(vals, columns)
df.show()
+---------+---------+
|SessionID|AlertType|
+---------+---------+
| 1| 0|
| 1| 1|
| 1| 2|
| 1| 3|
| 1| 4|
| 2| 0|
| 2| 1|
| 2| 2|
| 2| 3|
| 2| 4|
+---------+---------+
from pyspark.sql.functions import collect_list,max,lit, udf
from pyspark.ml.linalg import Vectors,VectorUDT
def encode(arr, length):
    """Encode one session's string-indexed AlertTypes as a sparse count vector.

    arr    -- list of string-indexed AlertType values (doubles) collected
              for a single SessionID; duplicates are allowed.
    length -- number of distinct AlertType categories, i.e. the vector
              size; may arrive as a float from Spark, so it is cast to int.

    Returns a pyspark.ml.linalg.SparseVector whose value at index i is the
    number of times category i occurred in the session.

    Why: Vectors.sparse requires each index to appear at most once, in
    increasing order — passing [(x, 1.0) for x in arr] raises (the
    "iterator"-related error) as soon as an AlertType repeats within a
    session. Counting occurrences both fixes the error and produces
    values greater than 1.0 for repeated AlertTypes, as desired.
    """
    counts = {}
    for x in arr:
        idx = int(x)
        counts[idx] = counts.get(idx, 0.0) + 1.0
    # sorted() guarantees the strictly-increasing index order sparse() expects.
    return Vectors.sparse(int(length), sorted(counts.items()))

encode_udf = udf(encode, VectorUDT())
# do stringindexer stuff
# Map AlertType categories onto contiguous double indices (most frequent
# category gets 0.0), which become the sparse-vector positions.
from pyspark.ml.feature import StringIndexer

alert_indexer = StringIndexer(
    inputCol='AlertType',
    outputCol='AlertTypeStrIndexed',
).fit(df)
df_strIndexed = alert_indexer.transform(df)
df_strIndexed.show()
+---------+---------+-------------------+
|SessionID|AlertType|AlertTypeStrIndexed|
+---------+---------+-------------------+
| 1| 0| 2.0|
| 1| 1| 1.0|
| 1| 2| 3.0|
| 1| 3| 4.0|
| 1| 4| 0.0|
| 2| 0| 2.0|
| 2| 1| 1.0|
| 2| 2| 3.0|
| 2| 3| 4.0|
| 2| 4| 0.0|
+---------+---------+-------------------+
df_strIndexed.agg(max(df_strIndexed["AlertTypeStrIndexed"])).show()
# Vector size = highest index + 1. StringIndexer emits doubles, so the
# aggregated max is a float (e.g. 4.0) — cast to int, because
# Vectors.sparse requires an integer size.
feats = int(df_strIndexed.agg(max(df_strIndexed["AlertTypeStrIndexed"])).take(1)[0][0]) + 1
# Collect each session's indexed alert types into an array, then encode
# the array as a sparse count vector.
df_OHE_grouped = (
    df_strIndexed.groupBy("SessionID")
    .agg(collect_list("AlertTypeStrIndexed").alias("AlertArray"))
    .select("SessionID", encode_udf("AlertArray", lit(feats)).alias("OHE"))
)
# Bug fix: show() returns None, so it must not be chained into the
# assignment above — otherwise df_OHE_grouped is None, not a DataFrame.
df_OHE_grouped.show(truncate=False)
+---------+-------------------------------------+
|SessionID|OHE                                  |
+---------+-------------------------------------+
|1        |(5,[0,1,2,3,4],[1.0,1.0,1.0,1.0,1.0])|
|2        |(5,[0,1,2,3,4],[1.0,1.0,1.0,1.0,1.0])|
+---------+-------------------------------------+