我需要处理数据集以识别频繁的项目集。因此输入列必须是向量。原始列是一个字符串,其中的项目用逗号分隔,因此我执行了以下操作:
functions.split(out_1['skills'], ',')
问题是,对于某些行,我在skills
中有重复的值,这在尝试识别频繁项目集时会导致错误。
我想将矢量转换为一组以删除重复的元素。像这样:
functions.to_set(functions.split(out_1['skills'], ','))
但我找不到将列从vector转换为set的函数,即没有to_set
函数。
如何实现我想要的,即从矢量中删除重复的元素?
答案 0 :(得分:2)
您可以使用set
将python中的functions.udf(set)
函数转换为udf,然后将其应用于数组列:
df.show()
+-------+
| skills|
+-------+
|a,a,b,c|
| a,b,c|
|c,d,e,e|
+-------+
import pyspark.sql.functions as F
df.withColumn("unique_skills", F.udf(set)(F.split(df.skills, ","))).show()
+-------+-------------+
| skills|unique_skills|
+-------+-------------+
|a,a,b,c| [a, b, c]|
| a,b,c| [a, b, c]|
|c,d,e,e| [c, d, e]|
+-------+-------------+