求pyspark数组的均值<double>

时间:2019-04-03 19:05:09

标签: apache-spark pyspark apache-spark-sql

在pyspark中,我有一个可变长度的双精度数组,我想找到其均值。但是,平均值函数需要单个数字类型。

有没有一种方法可以找到一个数组的平均值而不爆炸该数组?我有几个不同的数组,我希望能够执行以下操作:

df.select(col("Segment.Points.trajectory_points.longitude"))

DataFrame [经度:数组]

df.select(avg(col("Segment.Points.trajectory_points.longitude"))).show()
org.apache.spark.sql.AnalysisException: cannot resolve
'avg(Segment.Points.trajectory_points.longitude)' due to data type
mismatch: function average requires numeric types, not
ArrayType(DoubleType,true);;

如果我有3个具有以下数组的唯一记录,我希望将这些值的平均值作为输出。这将是3个平均经度值。

输入:

[Row(longitude=[-80.9, -82.9]),
 Row(longitude=[-82.92, -82.93, -82.94, -82.96, -82.92, -82.92]),
 Row(longitude=[-82.93, -82.93])]

输出:

-81.9,
-82.931,
-82.93

我正在使用Spark版本2.1.3。


爆炸解决方案:

所以我通过爆炸使它起作用,但是我希望避免这一步。这就是我所做的

from pyspark.sql.functions import col
import pyspark.sql.functions as F

longitude_exp = df.select(
    col("ID"), 
    F.posexplode("Segment.Points.trajectory_points.longitude").alias("pos", "longitude")
)

longitude_reduced = long_exp.groupBy("ID").agg(avg("longitude"))

这成功地取了意思。但是,由于我将在几列中执行此操作,因此必须将同一DF爆炸几次。我将继续研究它,以找到一种更干净的方法来实现此目的。

2 个答案:

答案 0 :(得分:2)

在最新的Spark版本(2.4或更高版本)中,最有效的解决方案是使用aggregate高阶函数:

from pyspark.sql.functions import expr

query = """aggregate(
    `{col}`,
    CAST(0.0 AS double),
    (acc, x) -> acc + x,
    acc -> acc / size(`{col}`)
) AS  `avg_{col}`""".format(col="longitude")

df.selectExpr("*", query).show()
+--------------------+------------------+
|           longitude|     avg_longitude|
+--------------------+------------------+
|      [-80.9, -82.9]|             -81.9|
|[-82.92, -82.93, ...|-82.93166666666667|
|    [-82.93, -82.93]|            -82.93|
+--------------------+------------------+

另请参阅Spark Scala row-wise average by handling null

答案 1 :(得分:1)

根据您的情况,您可以选择使用explodeudf。您已经注意到,explode不必要地昂贵。因此,udf是必经之路。

您可以编写自己的函数以获取数字列表的均值,也可以仅抄送numpy.mean。如果使用numpy.mean,则必须将结果强制转换为float(因为spark不知道如何处理numpy.float64 s)。

import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

array_mean = udf(lambda x: float(np.mean(x)), FloatType())
df.select(array_mean("longitude").alias("avg")).show()
#+---------+
#|      avg|
#+---------+
#|    -81.9|
#|-82.93166|
#|   -82.93|
#+---------+