Applying a function to lists in a PySpark column

Date: 2018-07-11 11:06:19

标签: apache-spark pyspark apache-spark-sql user-defined-functions

>>> df = hc.createDataFrame([('a', [1.0, 1.0]), ('a', [1.0, 0.2, 0.3, 0.7]), ('b', [1.0]), ('c', [1.0, 0.5]), ('d', [0.55, 1.0, 1.4]), ('e', [1.05, 1.0])])


>>> df.show()
+---+--------------------+
| _1|                  _2|
+---+--------------------+
|  a|          [1.0, 1.0]|
|  a|[1.0, 0.2, 0.3, 0.7]|
|  b|               [1.0]|
|  c|          [1.0, 0.5]|
|  d|    [0.55, 1.0, 1.4]|
|  e|         [1.05, 1.0]|
+---+--------------------+

Now I want to apply a function such as sum or mean to the "_2" column to create a "_3" column. For example, using the sum function, the result should look like this:

+---+--------------------+----+
| _1|                  _2|  _3|
+---+--------------------+----+
|  a|          [1.0, 1.0]| 2.0|
|  a|[1.0, 0.2, 0.3, 0.7]| 2.2|
|  b|               [1.0]| 1.0|
|  c|          [1.0, 0.5]| 1.5|
|  d|    [0.55, 1.0, 1.4]|2.95|
|  e|         [1.05, 1.0]|2.05|
+---+--------------------+----+

Thanks in advance.

1 Answer:

Answer 0 (score: 2)

TL;DR: Unless you use proprietary extensions, you have to define a UserDefinedFunction for each operation:

from pyspark.sql.functions import udf
import numpy as np

@udf("double")
def array_sum(xs):
    # Sum the array elements with NumPy; .tolist() converts the NumPy
    # scalar back to a Python float, and null arrays stay null
    return np.sum(xs).tolist() if xs is not None else None

@udf("double")
def array_mean(xs):
    # Same pattern for the mean
    return np.mean(xs).tolist() if xs is not None else None

(df
    .withColumn("mean", array_mean("_2"))
    .withColumn("sum", array_sum("_2")))

In some cases you might prefer to explode and aggregate instead, but this approach has limited applicability and is usually much more expensive, unless the data is already partitioned by a unique identifier.

from pyspark.sql.functions import monotonically_increasing_id, first, mean, sum, explode

(df
    # Tag each row with a unique id so rows can be reassembled after explode
    .withColumn("_id", monotonically_increasing_id())
    # One output row per array element
    .withColumn("x", explode("_2"))
    .groupBy("_id")
    .agg(first("_1"), first("_2"), mean("x"), sum("x")))