爆炸阵列大小为'n'

时间:2018-05-15 21:12:00

标签: apache-spark pyspark

希望将具有Spark的嵌套数组分解为批处理。下面的列是XML文件的嵌套数组。现在尝试将时间序列数据写入批处理以便写入NoSQL数据库。例如:

+-------+-----------------------+
|   ID  |               Example |
+-------+-----------------------+
|      A|   [[1,2],[3,4],[5,6]] |
+-------+-----------------------+

批量为2的批次输出

+-------+-----------------------+
|   ID  |               Example |
+-------+-----------------------+
|      A|         [[1,2],[3,4]] |
+-------+-----------------------+
|      A|               [[5,6]] |
+-------+-----------------------+

1 个答案:

答案 0 :(得分:1)

对于Spark v 2.1 +

您可以利用pyspark.sql.functions.posexplode()将列与其在数组中显示的索引一起展开,然后将结果位置除以n以创建组。

例如,以下是在DataFrame上使用posexplode()的输出:

import pyspark.sql.functions as f
df.select('ID', f.posexplode('Example')).show()
#+---+---+------+
#| ID|pos|   col|
#+---+---+------+
#|  A|  0|[1, 2]|
#|  A|  1|[3, 4]|
#|  A|  2|[5, 6]|
#+---+---+------+

请注意,我们有两列:poscol而不是一列。由于我们需要n的群组,因此我们可以简单地将pos除以n,然后使用floor来获取群组。

n = 2
df.select('ID', f.posexplode('Example'))\
    .withColumn("group", f.floor(f.col("pos")/n))\
    .show(truncate=False)
#+---+---+------+-----+
#|ID |pos|col   |group|
#+---+---+------+-----+
#|A  |0  |[1, 2]|0    |
#|A  |1  |[3, 4]|0    |
#|A  |2  |[5, 6]|1    |
#+---+---+------+-----+

现在按"ID""group"进行分组,然后使用pyspark.sql.functions.collect_list()获取所需的输出。

df.select('ID', f.posexplode('Example'))\
    .withColumn("group", f.floor(f.col("pos")/n))\
    .groupBy("ID", "group")\
    .agg(f.collect_list("col").alias("Example"))\
    .sort("group")\
    .drop("group")\
    .show(truncate=False)
#+---+----------------------------------------+
#|ID |Example                                 |
#+---+----------------------------------------+
#|A  |[WrappedArray(1, 2), WrappedArray(3, 4)]|
#|A  |[WrappedArray(5, 6)]                    |
#+---+----------------------------------------+

您会看到我也按"group"列排序并将其删除,但根据您的需要,这是可选的。

适用于旧版本的Spark

还有一些其他方法适用于2.1以下的Spark版本。所有这些方法都产生与上述相同的输出。

<强> 1。使用udf

您可以使用udf将阵列分组。例如:

def get_groups(array, n):
    return filter(lambda x: x, [array[i*n:(i+1)*n] for i in range(len(array))])

get_groups_of_2 = f.udf(
    lambda x: get_groups(x, 2),
    ArrayType(ArrayType(ArrayType(IntegerType())))
)

df.select("ID", f.explode(get_groups_of_2("Example")).alias("Example"))\
    .show(truncate=False)

get_groups()函数将获取一个数组并返回一组由n个元素组成的数组。

<强> 2。使用rdd

另一个选项是序列化为rdd,并在get_groups()调用中使用map()函数。然后转换回DataFrame。您必须指定此转换的架构才能正常工作。

n = 2

schema = StructType(
    [
        StructField("ID", StringType()),
        StructField("Example", ArrayType(ArrayType(ArrayType(IntegerType()))))
    ]
)

df.rdd.map(lambda x: (x["ID"], get_groups(x["Example"], n=n)))\
    .toDF(schema)\
    .select("ID", f.explode("Example").alias("Example"))\
    .show(truncate=False)