希望将具有Spark的嵌套数组分解为批处理。下面的列是XML文件的嵌套数组。现在尝试将时间序列数据写入批处理以便写入NoSQL数据库。例如:
+-------+-----------------------+
| ID | Example |
+-------+-----------------------+
| A| [[1,2],[3,4],[5,6]] |
+-------+-----------------------+
批量为2的批次输出
+-------+-----------------------+
| ID | Example |
+-------+-----------------------+
| A| [[1,2],[3,4]] |
+-------+-----------------------+
| A| [[5,6]] |
+-------+-----------------------+
答案 0 :(得分:1)
您可以利用pyspark.sql.functions.posexplode()
将列与其在数组中显示的索引一起展开,然后将结果位置除以n
以创建组。
例如,以下是在DataFrame上使用posexplode()
的输出:
import pyspark.sql.functions as f
df.select('ID', f.posexplode('Example')).show()
#+---+---+------+
#| ID|pos| col|
#+---+---+------+
#| A| 0|[1, 2]|
#| A| 1|[3, 4]|
#| A| 2|[5, 6]|
#+---+---+------+
请注意,我们有两列:pos
和col
而不是一列。由于我们需要n
的群组,因此我们可以简单地将pos
除以n
,然后使用floor
来获取群组。
n = 2
df.select('ID', f.posexplode('Example'))\
.withColumn("group", f.floor(f.col("pos")/n))\
.show(truncate=False)
#+---+---+------+-----+
#|ID |pos|col |group|
#+---+---+------+-----+
#|A |0 |[1, 2]|0 |
#|A |1 |[3, 4]|0 |
#|A |2 |[5, 6]|1 |
#+---+---+------+-----+
现在按"ID"
和"group"
进行分组,然后使用pyspark.sql.functions.collect_list()
获取所需的输出。
df.select('ID', f.posexplode('Example'))\
.withColumn("group", f.floor(f.col("pos")/n))\
.groupBy("ID", "group")\
.agg(f.collect_list("col").alias("Example"))\
.sort("group")\
.drop("group")\
.show(truncate=False)
#+---+----------------------------------------+
#|ID |Example |
#+---+----------------------------------------+
#|A |[WrappedArray(1, 2), WrappedArray(3, 4)]|
#|A |[WrappedArray(5, 6)] |
#+---+----------------------------------------+
您会看到我也按"group"
列排序并将其删除,但根据您的需要,这是可选的。
还有一些其他方法适用于2.1以下的Spark版本。所有这些方法都产生与上述相同的输出。
<强> 1。使用udf
您可以使用udf
将阵列分组。例如:
def get_groups(array, n):
return filter(lambda x: x, [array[i*n:(i+1)*n] for i in range(len(array))])
get_groups_of_2 = f.udf(
lambda x: get_groups(x, 2),
ArrayType(ArrayType(ArrayType(IntegerType())))
)
df.select("ID", f.explode(get_groups_of_2("Example")).alias("Example"))\
.show(truncate=False)
get_groups()
函数将获取一个数组并返回一组由n个元素组成的数组。
<强> 2。使用rdd
另一个选项是序列化为rdd
,并在get_groups()
调用中使用map()
函数。然后转换回DataFrame。您必须指定此转换的架构才能正常工作。
n = 2
schema = StructType(
[
StructField("ID", StringType()),
StructField("Example", ArrayType(ArrayType(ArrayType(IntegerType()))))
]
)
df.rdd.map(lambda x: (x["ID"], get_groups(x["Example"], n=n)))\
.toDF(schema)\
.select("ID", f.explode("Example").alias("Example"))\
.show(truncate=False)