Suppose the DataFrame df is:
df.show()
Output:
+------+----------------+
|letter| list_of_numbers|
+------+----------------+
| A| [3, 1, 2, 3]|
| B| [1, 2, 1, 1]|
+------+----------------+
What I want to do is count the occurrences of a specific element in the list_of_numbers column. Like this:
+------+----------------+----+
|letter| list_of_numbers|ones|
+------+----------------+----+
| A| [3, 1, 2, 3]| 1|
| B| [1, 2, 1, 1]| 3|
+------+----------------+----+
So far I have tried writing a udf, and it works fine, but I would like to know whether this can be done without defining any udf.
Answer 0: (score: 3)
You can explode the array and keep only the exploded values equal to 1. Then groupBy and count:
from pyspark.sql.functions import col, count, explode
df.select("*", explode("list_of_numbers").alias("exploded"))\
.where(col("exploded") == 1)\
.groupBy("letter", "list_of_numbers")\
.agg(count("exploded").alias("ones"))\
.show()
#+------+---------------+----+
#|letter|list_of_numbers|ones|
#+------+---------------+----+
#| A| [3, 1, 2, 3]| 1|
#| B| [1, 2, 1, 1]| 3|
#+------+---------------+----+
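To keep all rows, even those where the count is 0, you can instead convert the exploded column into an indicator variable (1 if the value equals 1, else 0). Then groupBy and sum.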
from pyspark.sql.functions import col, explode, sum as sum_
df.select("*", explode("list_of_numbers").alias("exploded"))\
.withColumn("exploded", (col("exploded") == 1).cast("int"))\
.groupBy("letter", "list_of_numbers")\
.agg(sum_("exploded").alias("ones"))\
.show()
Note that I imported pyspark.sql.functions.sum as sum_ so that it does not shadow Python's built-in sum function.
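To see why the filter-then-count version drops rows with no matches while the indicator-and-sum version keeps them, here is a minimal plain-Python model of the two pipelines. It does not use Spark at all; the helper names are made up for illustration, and the row "C" is a hypothetical addition to the sample data that contains no 1s.

```python
from collections import defaultdict

# Hypothetical sample rows; "C" is added here to show the zero-count case.
rows = [
    ("A", [3, 1, 2, 3]),
    ("B", [1, 2, 1, 1]),
    ("C", [2, 3]),
]


def explode(rows):
    """Yield one (letter, numbers, element) triple per array element."""
    for letter, numbers in rows:
        for element in numbers:
            yield letter, tuple(numbers), element


def count_after_filter(rows, target=1):
    """Model of explode -> where(element == target) -> groupBy -> count."""
    counts = defaultdict(int)
    for letter, numbers, element in explode(rows):
        if element == target:  # rows with no match never reach the grouping
            counts[(letter, numbers)] += 1
    return {key[0]: n for key, n in counts.items()}


def sum_of_indicators(rows, target=1):
    """Model of explode -> cast(element == target) to int -> groupBy -> sum."""
    sums = defaultdict(int)
    for letter, numbers, element in explode(rows):
        sums[(letter, numbers)] += int(element == target)  # adds 0 or 1
    return {key[0]: n for key, n in sums.items()}


filtered = count_after_filter(rows)
indicator = sum_of_indicators(rows)
print(filtered)   # {'A': 1, 'B': 3}
print(indicator)  # {'A': 1, 'B': 3, 'C': 0}
```

In this model a row whose list is empty would still vanish from both results, because exploding it produces no triples to group.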
Answer 1: (score: 1)
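Assuming every list has the same, known length (4 here), one way I can think of is to build one 0/1 indicator per array position and add them up: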
from operator import add
from functools import reduce
import pyspark.sql.functions as F
# `sql` is assumed to be an existing SQLContext (or SparkSession)
df = sql.createDataFrame(
[
['A',[3, 1, 2, 3]],
['B',[1, 2, 1, 1]]
],
['letter','list_of_numbers'])
expr = reduce(add,[F.when(F.col('list_of_numbers').getItem(x)==1, 1)\
.otherwise(0) for x in range(4)])
df = df.withColumn('ones', expr)
df.show()
+------+---------------+----+
|letter|list_of_numbers|ones|
+------+---------------+----+
| A| [3, 1, 2, 3]| 1|
| B| [1, 2, 1, 1]| 3|
+------+---------------+----+
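The per-position sum above can be modelled in plain Python to show what the hard-coded range(4) implies for lists of other lengths. This is only a sketch: get_item and count_ones are hypothetical helpers, and the model assumes an out-of-range lookup yields a null-like None (which, as far as I know, is what Spark does when ANSI mode is off) rather than raising an error.

```python
from functools import reduce
from operator import add


def get_item(values, index):
    """Mimic an array lookup that yields None when the index is out of range."""
    return values[index] if index < len(values) else None


def count_ones(values, width):
    """Sum one 0/1 indicator per index in range(width)."""
    indicators = [1 if get_item(values, i) == 1 else 0 for i in range(width)]
    return reduce(add, indicators)


print(count_ones([1, 2, 1, 1], width=4))        # 3: all four positions inspected
print(count_ones([1, 2, 1, 1, 1, 1], width=4))  # 3: positions 4 and 5 are never inspected
print(count_ones([1], width=4))                 # 1: missing positions contribute 0
```

Shorter lists are harmless under that assumption, but any list longer than the chosen width is undercounted, so the width has to be at least the longest list length.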
Answer 2: (score: 0)
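Ala Tarighati commented above that this solution does not work for arrays of different lengths. The following variation addresses that by checking every position up to an upper bound on the array length: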
from operator import add
from functools import reduce
import pyspark.sql.functions as F
# `sql` is assumed to be an existing SQLContext (or SparkSession)
df = sql.createDataFrame(
[
['A',[3, 1, 2, 3]],
['B',[1, 2, 1, 1]]
],
['letter','list_of_numbers'])
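# Number of positions to inspect: the length of the longest array in the column.
max_len = df.select(F.max(F.size("list_of_numbers"))).first()[0]
df_ones = (
    df.withColumn(
        'ones',
        reduce(
            add,
            [
                # Out-of-range getItem is assumed to return null (non-ANSI mode), so it adds 0.
                F.when(
                    F.col("list_of_numbers").getItem(x) == F.lit(1), 1
                ).otherwise(0)
                for x in range(max_len)
            ],
        ),
    )
)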
df_ones.show()
+------+---------------+----+
|letter|list_of_numbers|ones|
+------+---------------+----+
| A| [3, 1, 2, 3]| 1|
| B| [1, 2, 1, 1]| 3|
+------+---------------+----+