我有一个数组列表,我需要为其列表中的每个元素找到频率最高的元素。对于以下代码,抛出“ unhashable type:'list'”错误。但是,我也尝试并行化结果列表但是错误仍然存在。
# [array(0,1,1),array(0,0,1),array(1,1,0)] example of list
def finalml(listn):
return Counter(listn).most_common(1)
# the array list is return by this
results = sn.rdd.map(lambda xw: bc_knnobj.value.kneighbors(xw, return_distance=False)).collect()
labels = results.map(lambda xw: finalml(xw)).collect()
预期输出 [1,0,1]
答案 0 :(得分:1)
尝试一下:
x = [[0,1,1],[0,0,1],[1,1,0]]
df = spark.createDataFrame(x)
df.show()
输入df:
+---+---+---+
| _1| _2| _3|
+---+---+---+
| 0| 1| 1|
| 0| 0| 1|
| 1| 1| 0|
+---+---+---+
import pyspark.sql.functions as F
@F.udf
def mode(x):
from collections import Counter
return Counter(x).most_common(1)[0][0]
cols = df.columns
agg_expr = [mode(F.collect_list(col)).alias(col) for col in cols]
df.groupBy().agg(*agg_expr).show()
输出df:
+---+---+---+
| _1| _2| _3|
+---+---+---+
| 0| 1| 1|
+---+---+---+