扁平化pyspark中sql.dataframe.DataFrame的数组的数组(不同维度)

时间:2019-12-09 19:34:50

标签: pyspark

我有一个pyspark.sql.dataframe.DataFrame,它是这样的:

+---------------------------+--------------------+--------------------+
|collect_list(results)      |        userid      |         page       |
+---------------------------+--------------------+--------------------+
|       [[[roundtrip, fal...|13482f06-9185-47f...|1429d15b-91d0-44b...|
+---------------------------+--------------------+--------------------+

collect_list(results)列中有一个len = 2的数组,元素也是数组(第一个数组的len = 1,第二个数组的len = 9)。

是否有一种方法可以使用pyspark将这个数组数组展平为len = 10的唯一数组?

谢谢!

2 个答案:

答案 0 :(得分:2)

您可以使用pyspark.sql.functions.flatten展平数组。文档here。例如,这将创建一个名为results的新列,假设您的数据帧变量名为df,结果将变平。

import pyspark.sql.functions as F
...
df.withColumn('results', F.flatten('collect_list(results)')

答案 1 :(得分:1)

对于在Spark 2.4之前(但在1.3之前)运行的版本,您可以尝试groupBy在分组之前获得的数据集,从而取消嵌套数组的一级,然后调用collect_listfrom pyspark.sql.functions import collect_list, explode df = spark.createDataFrame([("foo", [1,]), ("foo", [2, 3])], schema=("foo", "bar")) df.show() # +---+------+ # |foo| bar| # +---+------+ # |foo| [1]| # |foo|[2, 3]| # +---+------+ (df.select( df.foo, explode(df.bar)) .groupBy("foo") .agg(collect_list("col")) .show()) # +---+-----------------+ # |foo|collect_list(col)| # +---+-----------------+ # |foo| [1, 2, 3]| # +---+-----------------+ 。像这样:

{{1}}
相关问题