pyspark。数据框

时间:2018-04-12 15:29:32

标签: apache-spark pyspark apache-spark-sql

我有以下PySpark DataFrame:

+------+----------------+
|    id|          data  |
+------+----------------+
|     1|    [10, 11, 12]|
|     2|    [20, 21, 22]|
|     3|    [30, 31, 32]|
+------+----------------+

最后,我想拥有以下DataFrame

+--------+----------------------------------+
|    id  |          data                    |
+--------+----------------------------------+
| [1,2,3]|[[10,20,30],[11,21,31],[12,22,32]]|
+--------+----------------------------------+

我命令这样做。首先,我提取数据数组如下:

tmp_array = df_test.select("data").rdd.flatMap(lambda x: x).collect()
a0 = tmp_array[0]
a1 = tmp_array[1]
a2 = tmp_array[2]
samples = zip(a0, a1, a2)
samples1 = sc.parallelize(samples)

通过这种方式,我在samples1中有一个内容为

的RDD
[[10,20,30],[11,21,31],[12,22,32]]
  • 问题1:这是一个好方法吗?

  • 问题2:如何将RDD包含回数据框?

2 个答案:

答案 0 :(得分:2)

您可以简单地使用udf函数作为zip函数,但在此之前您必须使用collect_list函数

from pyspark.sql import functions as f
from pyspark.sql import types as t
def zipUdf(array):
    return zip(*array)

zipping = f.udf(zipUdf, t.ArrayType(t.ArrayType(t.IntegerType())))

df.select(
    f.collect_list(df.id).alias('id'), 
    zipping(f.collect_list(df.data)).alias('data')
).show(truncate=False)

会给你

+---------+------------------------------------------------------------------------------+
|id       |data                                                                          |
+---------+------------------------------------------------------------------------------+
|[1, 2, 3]|[WrappedArray(10, 20, 30), WrappedArray(11, 21, 31), WrappedArray(12, 22, 32)]|
+---------+------------------------------------------------------------------------------+

答案 1 :(得分:2)

以下是一种获取所需输出的方法,无需序列化为rdd或使用udf。您将需要两个常量:

  • 您的DataFrame中的行数(df.count()
  • 数据长度(给定)

在双重列表理解中使用pyspark.sql.functions.collect_list()pyspark.sql.functions.array(),以pyspark.sql.Column.getItem()的顺序选择"data"的元素:

import pyspark.sql.functions as f
dataLength = 3
numRows = df.count()
df.select(
    f.collect_list("id").alias("id"),
    f.array(
        [
            f.array(
                [f.collect_list("data").getItem(j).getItem(i) 
                 for j in range(numRows)]
            ) 
            for i in range(dataLength)
        ]
    ).alias("data")
)\
.show(truncate=False)
#+---------+------------------------------------------------------------------------------+
#|id       |data                                                                          |
#+---------+------------------------------------------------------------------------------+
#|[1, 2, 3]|[WrappedArray(10, 20, 30), WrappedArray(11, 21, 31), WrappedArray(12, 22, 32)]|
#+---------+------------------------------------------------------------------------------+