我有一个pyspark.sql.dataframe.DataFrame
,它是这样的:
+---------------------------+--------------------+--------------------+
|collect_list(results) | userid | page |
+---------------------------+--------------------+--------------------+
| [[[roundtrip, fal...|13482f06-9185-47f...|1429d15b-91d0-44b...|
+---------------------------+--------------------+--------------------+
collect_list(results)列中有一个len = 2的数组,元素也是数组(第一个数组的len = 1,第二个数组的len = 9)。
是否有一种方法可以使用pyspark将这个数组数组展平为len = 10的唯一数组?
谢谢!
答案 0 :(得分:2)
您可以使用pyspark.sql.functions.flatten
展平数组。文档here。例如,这将创建一个名为results
的新列,假设您的数据帧变量名为df
,结果将变平。
import pyspark.sql.functions as F
...
df.withColumn('results', F.flatten('collect_list(results)')
答案 1 :(得分:1)
对于在Spark 2.4之前(但在1.3之前)运行的版本,您可以尝试groupBy
在分组之前获得的数据集,从而取消嵌套数组的一级,然后调用collect_list
并from pyspark.sql.functions import collect_list, explode
df = spark.createDataFrame([("foo", [1,]), ("foo", [2, 3])], schema=("foo", "bar"))
df.show()
# +---+------+
# |foo| bar|
# +---+------+
# |foo| [1]|
# |foo|[2, 3]|
# +---+------+
(df.select(
df.foo,
explode(df.bar))
.groupBy("foo")
.agg(collect_list("col"))
.show())
# +---+-----------------+
# |foo|collect_list(col)|
# +---+-----------------+
# |foo| [1, 2, 3]|
# +---+-----------------+
。像这样:
{{1}}