PySpark dataframe: Count elements in an array or list

Date: 2018-09-28 07:27:36

Tags: arrays, list, dataframe, pyspark, counting

Suppose we have a dataframe df:

df.show()

Output:

+------+----------------+
|letter| list_of_numbers|
+------+----------------+
|     A|    [3, 1, 2, 3]|
|     B|    [1, 2, 1, 1]|
+------+----------------+
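
For reference, such a dataframe can be created as follows (a minimal sketch, assuming a SparkSession named spark):

from pyspark.sql import SparkSession

# Assumed setup: a SparkSession and the example data
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [('A', [3, 1, 2, 3]), ('B', [1, 2, 1, 1])],
    ['letter', 'list_of_numbers'])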

What I want to do is count the number of occurrences of a specific element in list_of_numbers, like this:

+------+----------------+----+
|letter| list_of_numbers|ones|
+------+----------------+----+
|     A|    [3, 1, 2, 3]|   1|
|     B|    [1, 2, 1, 1]|   3|
+------+----------------+----+

So far I have tried creating a udf, and it works fine, but I am wondering whether this can be done without defining any udf.
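
For reference, the udf approach mentioned above might look something like this (a minimal sketch; the actual udf is not shown in the question):

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# Count how many times 1 appears in each array (hypothetical udf version)
count_ones = udf(lambda nums: nums.count(1), IntegerType())
df.withColumn("ones", count_ones("list_of_numbers")).show()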

3 Answers:

Answer 0 (score: 3)

You can explode the array and filter the exploded values for 1, then groupBy and count:

from pyspark.sql.functions import col, count, explode

df.select("*", explode("list_of_numbers").alias("exploded"))\
    .where(col("exploded") == 1)\
    .groupBy("letter", "list_of_numbers")\
    .agg(count("exploded").alias("ones"))\
    .show()
#+------+---------------+----+
#|letter|list_of_numbers|ones|
#+------+---------------+----+
#|     A|   [3, 1, 2, 3]|   1|
#|     B|   [1, 2, 1, 1]|   3|
#+------+---------------+----+

To keep all rows, even those where the count is 0, you can convert the exploded column into an indicator variable and then groupBy and sum. (With the filter above, a row whose array contains no 1s loses all of its exploded rows and disappears from the result.)

from pyspark.sql.functions import col, explode, sum as sum_

df.select("*", explode("list_of_numbers").alias("exploded"))\
    .withColumn("exploded", (col("exploded") == 1).cast("int"))\
    .groupBy("letter", "list_of_numbers")\
    .agg(sum_("exploded").alias("ones"))\
    .show()

Note that I have imported pyspark.sql.functions.sum as sum_ so that it does not shadow the built-in sum function.
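
If your Spark version is 2.4 or later (newer than what was current when this question was asked), the same count can be written in a single expression with the SQL higher-order function filter; a minimal sketch:

from pyspark.sql.functions import expr

# Keep only the 1s in the array, then take the size of the filtered array
df.withColumn("ones", expr("size(filter(list_of_numbers, x -> x = 1))")).show()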

Answer 1 (score: 1)

Assuming the length of the list is constant, one way I can think of is:

from operator import add
from functools import reduce
import pyspark.sql.functions as F

# 'sql' is assumed to be an existing SparkSession (or SQLContext)
df = sql.createDataFrame(
    [
        ['A', [3, 1, 2, 3]],
        ['B', [1, 2, 1, 1]]
    ],
    ['letter', 'list_of_numbers'])

# Build an indicator (1 or 0) for each of the four fixed positions and sum them
expr = reduce(add, [F.when(F.col('list_of_numbers').getItem(x) == 1, 1)
                     .otherwise(0) for x in range(4)])
df = df.withColumn('ones', expr)
df.show()

+------+---------------+----+
|letter|list_of_numbers|ones|
+------+---------------+----+
|     A|   [3, 1, 2, 3]|   1|
|     B|   [1, 2, 1, 1]|   3|
+------+---------------+----+

Answer 2 (score: 0)

Ala Tarighati commented above that that solution does not work for arrays of different lengths. Below is a version that addresses that problem:

from operator import add
from functools import reduce
import pyspark.sql.functions as F

# 'sql' is assumed to be an existing SparkSession (or SQLContext)
df = sql.createDataFrame(
    [
        ['A', [3, 1, 2, 3]],
        ['B', [1, 2, 1, 1]]
    ],
    ['letter', 'list_of_numbers'])

# Iterate up to the length of the longest array in the column.
# getItem on an out-of-range index returns null, and when(null, 1)
# falls through to otherwise(0), so shorter arrays are handled correctly.
max_len = df.select(F.max(F.size('list_of_numbers'))).first()[0]

df_ones = df.withColumn(
    'ones',
    reduce(
        add,
        [
            F.when(F.col('list_of_numbers').getItem(x) == 1, 1)
             .otherwise(0)
            for x in range(max_len)
        ],
    ),
)
df_ones.show()
+------+---------------+----+
|letter|list_of_numbers|ones|
+------+---------------+----+
|     A|   [3, 1, 2, 3]|   1|
|     B|   [1, 2, 1, 1]|   3|
+------+---------------+----+
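
To sanity-check that this handles arrays of different lengths, one can add a shorter row (hypothetical extra data) and re-run the same expression; row C below should report ones = 2:

df2 = sql.createDataFrame(
    [
        ['A', [3, 1, 2, 3]],
        ['B', [1, 2, 1, 1]],
        ['C', [1, 1]]
    ],
    ['letter', 'list_of_numbers'])

max_len = df2.select(F.max(F.size('list_of_numbers'))).first()[0]
df2.withColumn(
    'ones',
    reduce(add, [F.when(F.col('list_of_numbers').getItem(x) == 1, 1)
                  .otherwise(0) for x in range(max_len)]),
).show()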