How to find the average of an array in PySpark

Time: 2019-12-10 01:20:06

Tags: python pandas pyspark pyspark-sql pyspark-dataframes

I have a PySpark DataFrame in which one of the columns (say B) is an array of arrays. Here is the PySpark DataFrame:

+---+-----------------------------+---+
|A  |B                            |C  |
+---+-----------------------------+---+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |
|a  |[[5.0], [20.0, 80.0]]        |d  |
|a  |[[5.0], [25.0, 75.0]]        |e  |
|b  |[[5.0], [25.0, 75.0]]        |f  |
|b  |[[5.0], [12.0, 88.0]]        |g  |
+---+-----------------------------+---+

I want to find, for each row, the number of elements and the average of all elements (as separate columns).

Here is the expected output:

+---+-----------------------------+---+---+------+
|A  |B                            |C  |Num|   Avg|
+---+-----------------------------+---+---+------+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |4  | 23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |3  | 35.00|
|a  |[[5.0], [25.0, 75.0]]        |e  |3  | 35.00|
|b  |[[5.0], [25.0, 75.0]]        |f  |3  | 35.00|
|b  |[[5.0], [12.0, 88.0]]        |g  |3  | 35.00|
+---+-----------------------------+---+---+------+

What is an efficient way to find the average of all the elements of the array (per row) in PySpark?

Currently I am using UDFs for these operations. Here is the code I have so far:

from pyspark.sql import functions as F
import pyspark.sql.types as T
from pyspark.sql import *
from pyspark.sql.types import DecimalType
from pyspark.sql.functions import udf
import numpy as np

# UDF to find the number of elements
def len_array_of_arrays(anomaly_in_issue_group_col):
    return sum(len(array_element) for array_element in anomaly_in_issue_group_col)

udf_len_array_of_arrays = F.udf(len_array_of_arrays, T.IntegerType())

# UDF to find the average of all elements
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return np.mean([element for array_element in anomaly_in_issue_group_col for element in array_element])

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DecimalType())

df.withColumn("Num", udf_len_array_of_arrays(F.col("B"))).withColumn(
    "Avg", udf_avg_array_of_arrays(F.col("B"))
).show(20, False)

The UDF that finds the number of elements in each row works. However, the UDF that finds the average raises the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-176-3253feca2963> in <module>()
      1 #df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).show(20, False)
----> 2 df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).withColumn("Avg" ,  udf_avg_array_of_arrays(F.col("B")) ).show(20, False)

/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    378             print(self._jdf.showString(n, 20, vertical))
    379         else:
--> 380             print(self._jdf.showString(n, int(truncate), vertical))
    381 
    382     def __repr__(self):

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(
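
A likely cause of the error above is the UDF's return value: np.mean returns a numpy.float64, which Spark cannot serialize, and DecimalType() expects decimal.Decimal values rather than floats. A minimal sketch of a workaround, assuming a plain double result is acceptable:

import numpy as np
from pyspark.sql import functions as F
import pyspark.sql.types as T

# Return a plain Python float (not numpy.float64) and declare DoubleType,
# which Spark can convert directly
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return float(np.mean([element for array_element in anomaly_in_issue_group_col
                          for element in array_element]))

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DoubleType())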

2 Answers:

Answer 0 (score: 2)

For Spark 2.4+, use flatten + aggregate:

from pyspark.sql.functions import expr

df.withColumn("Avg", expr("""
    aggregate(
        flatten(B)
      , (double(0) as total, int(0) as cnt)
      , (x,y) -> (x.total+y, x.cnt+1)
      , z -> round(z.total/z.cnt,2)
    ) 
 """)).show()
+-----------------------------+---+-----+
|B                            |C  |Avg  |
+-----------------------------+---+-----+
|[[5.0], [25.0, 25.0], [40.0]]|c  |23.75|
|[[5.0], [25.0, 80.0]]        |d  |36.67|
|[[5.0], [25.0, 75.0]]        |e  |35.0 |
+-----------------------------+---+-----+
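
The snippet above only adds the Avg column; the element count asked for in the question can also be computed without a UDF by flattening first and taking the size. A minimal sketch, assuming the same column name B as in the question (size and flatten are built-ins since Spark 2.4):

from pyspark.sql.functions import expr

# size(flatten(B)) counts every element across the nested arrays in each row
df = df.withColumn("Num", expr("size(flatten(B))"))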

Answer 1 (score: 1)

Since Spark 1.4

explode() the column that contains the arrays as many times as there are levels of nesting. Use monotonically_increasing_id() to create an extra grouping key so that duplicate rows are not merged:

from pyspark.sql.functions import explode, sum, lit, avg, monotonically_increasing_id

df = spark.createDataFrame(
    (("a", [[1], [2, 3], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),  # DUPE!
     ("b", [[2, 3], [4]], "foo")),
    schema=("category", "arrays", "foo"))

df2 = (df.withColumn("id", monotonically_increasing_id())
       .withColumn("subarray", explode("arrays"))
       .withColumn("subarray", explode("subarray"))  # unnest another level
       .groupBy("category", "arrays", "foo", "id")
       .agg(sum(lit(1)).alias("number_of_elements"),
            avg("subarray").alias("avg")).drop("id"))
df2.show()
# +--------+------------------+---+------------------+----+  
# |category|            arrays|foo|number_of_elements| avg|
# +--------+------------------+---+------------------+----+
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       b|     [[2, 3], [4]]|foo|                 3| 3.0|
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       a|[[1], [2, 3], [4]]|foo|                 4| 2.5|
# +--------+------------------+---+------------------+----+

Spark 2.4 introduced 24 functions for working with complex types, along with higher-order functions (functions that take functions as arguments, like Python 3's functools.reduce). They do away with the boilerplate you see above. If you are on Spark 2.4+, see the answer from jxc.
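
For reference, a sketch of the same computation expressed through the DataFrame API rather than a SQL expression string; this assumes Spark 3.1+, where aggregate is exposed as a Python function (flatten and size are available from 2.4):

from pyspark.sql import functions as F

# Flatten the nested array once, then derive both the count and the average from it
result = (
    df.withColumn("flat", F.flatten("B"))
      .withColumn("Num", F.size("flat"))
      .withColumn("Avg", F.round(F.aggregate("flat", F.lit(0.0), lambda acc, x: acc + x) / F.col("Num"), 2))
      .drop("flat")
)
result.show(truncate=False)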