如何在PySpark中获取最少的嵌套列表

时间:2019-06-09 16:55:58

标签: apache-spark pyspark

例如,请参见以下数据框,

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('test').getOrCreate()
df = spark.createDataFrame([[[1, 2, 3, 4]],[[0, 2, 4]],[[]],[[3]]])
df.show()

那么我们有

+------------+
|          _1|
+------------+
|[1, 2, 3, 4]|
|   [0, 2, 4]|
|          []|
|         [3]|
+------------+

然后我想找到每个列表的最小值;如果列表为空,则使用-1。我尝试了以下方法,但这不起作用。

import pyspark.sql.functions as F
sim_col = F.col('_1')
df.withColumn('min_turn_sim', F.when(F.size(sim_col)==0, -1.0).otherwise(F.min(sim_col))).show()

错误是:

  

AnalysisException:由于数据类型不匹配,无法解析'((_1为NULL)THEN -1.0D ELSE min(_1)END'的情况:THEN和ELSE表达式都应为同一类型或强制为通用类型; \ n'聚合[_1#404,当为null(_1#404)THEN -1.0 ELSE min(_1#404)END AS min_turn_sim#411]时的情况\ n +-LogicalRDD [_1#404] ,false \ n“


大小功能将起作用。不明白为什么“ min”没有。

df.withColumn('min_turn_sim', F.when(F.size(sim_col)==0, -1.0).otherwise(F.size(sim_col))).show()

+------------+------------+
|          _1|min_turn_sim|
+------------+------------+
|[1, 2, 3, 4]|         4.0|
|   [0, 2, 4]|         3.0|
|          []|        -1.0|
|         [3]|         1.0|
+------------+------------+

1 个答案:

答案 0 :(得分:2)

min是一个聚合函数-它对列而不是值进行操作。因此,min(sim_col)表示根据数组顺序在作用域内所有行中的最小数组值,而不是每行中的最小值。

要找到每一行的最小值,您需要一个非聚合函数。在最新的Spark版本(2.4.0及更高版本)中,该值为array_min(类似于array_max以获取最大值):

df.withColumn("min_turn_sim", F.coalesce(F.array_min(sim_col), F.lit(-1)))

早期版本将需要UDF:

@F.udf("long")
def long_array_min(xs):
    return min(xs) if xs else -1

df.withColumn("min_turn_sim", F.coalesce(long_array_min(sim_col), F.lit(-1))