我有一个我试图压扁的Dataframe。作为整个过程的一部分,我想爆炸它,所以如果我有一列数组,那么数组的每个值都将用于创建一个单独的行。例如,
id | name | likes
_______________________________
1 | Luke | [baseball, soccer]
应该成为
id | name | likes
_______________________________
1 | Luke | baseball
1 | Luke | soccer
这是我的代码
private DataFrame explodeDataFrame(DataFrame df) {
DataFrame resultDf = df;
for (StructField field : df.schema().fields()) {
if (field.dataType() instanceof ArrayType) {
resultDf = resultDf.withColumn(field.name(), org.apache.spark.sql.functions.explode(resultDf.col(field.name())));
resultDf.show();
}
}
return resultDf;
}
问题是在我的数据中,一些数组列有空值。在这种情况下,整个行都将被删除。所以这个数据框:
id | name | likes
_______________________________
1 | Luke | [baseball, soccer]
2 | Lucy | null
成为
id | name | likes
_______________________________
1 | Luke | baseball
1 | Luke | soccer
而不是
id | name | likes
_______________________________
1 | Luke | baseball
1 | Luke | soccer
2 | Lucy | null
如何爆炸我的数组以便我不会丢失空行?
我正在使用Spark 1.5.2和Java 8
答案 0 :(得分:48)
Spark 2.2 +
您可以使用explode_outer
功能:
import org.apache.spark.sql.functions.explode_outer
df.withColumn("likes", explode_outer($"likes")).show
// +---+----+--------+
// | id|name| likes|
// +---+----+--------+
// | 1|Luke|baseball|
// | 1|Luke| soccer|
// | 2|Lucy| null|
// +---+----+--------+
Spark< = 2.1
在Scala中,Java等价物应该几乎相同(导入单个函数使用import static
)。
import org.apache.spark.sql.functions.{array, col, explode, lit, when}
val df = Seq(
(1, "Luke", Some(Array("baseball", "soccer"))),
(2, "Lucy", None)
).toDF("id", "name", "likes")
df.withColumn("likes", explode(
when(col("likes").isNotNull, col("likes"))
// If null explode an array<string> with a single null
.otherwise(array(lit(null).cast("string")))))
这里的想法基本上是将NULL
替换为所需类型的array(NULL)
。对于复杂类型(a.k.a structs
),您必须提供完整的架构:
val dfStruct = Seq((1L, Some(Array((1, "a")))), (2L, None)).toDF("x", "y")
val st = StructType(Seq(
StructField("_1", IntegerType, false), StructField("_2", StringType, true)
))
dfStruct.withColumn("y", explode(
when(col("y").isNotNull, col("y"))
.otherwise(array(lit(null).cast(st)))))
或
dfStruct.withColumn("y", explode(
when(col("y").isNotNull, col("y"))
.otherwise(array(lit(null).cast("struct<_1:int,_2:string>")))))
注意强>:
如果已创建数组Column
且containsNull
设置为false
,则应首先更改此项(使用Spark 2.1测试):
df.withColumn("array_column", $"array_column".cast(ArrayType(SomeType, true)))
答案 1 :(得分:1)
对接受的答案进行跟进,当数组元素是复杂类型时,很难手动定义它(例如,使用大型结构)。
为了自动完成,我编写了以下帮助方法:
def explodeOuter(df: Dataset[Row], columnsToExplode: List[String]) = {
val arrayFields = df.schema.fields
.map(field => field.name -> field.dataType)
.collect { case (name: String, type: ArrayType) => (name, type.asInstanceOf[ArrayType])}
.toMap
columnsToExplode.foldLeft(df) { (dataFrame, arrayCol) =>
dataFrame.withColumn(arrayCol, explode(when(size(col(arrayCol)) =!= 0, col(arrayCol))
.otherwise(array(lit(null).cast(arrayFields(arrayCol).elementType)))))
}
答案 2 :(得分:1)
您可以使用explode_outer()
功能。
答案 3 :(得分:0)
要处理空的地图类型列:对于Spark <= 2.1
List((1, Array(2, 3, 4), Map(1 -> "a")),
(2, Array(5, 6, 7), Map(2 -> "b")),
(3, Array[Int](), Map[Int, String]())).toDF("col1", "col2", "col3").show()
df.select('col1, explode(when(size(map_keys('col3)) === 0, map(lit("null"), lit("null"))).
otherwise('col3))).show()
答案 4 :(得分:0)
from pyspark.sql.functions import *
def flatten_df(nested_df):
flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']
flat_df = nested_df.select(flat_cols +
[col(nc + '.' + c).alias(nc + '_' + c)
for nc in nested_cols
for c in nested_df.select(nc + '.*').columns])
print("flatten_df_count :", flat_df.count())
return flat_df
def explode_df(nested_df):
flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct' and c[1][:5] != 'array']
array_cols = [c[0] for c in nested_df.dtypes if c[1][:5] == 'array']
for array_col in array_cols:
schema = new_df.select(array_col).dtypes[0][1]
nested_df = nested_df.withColumn(array_col, when(col(array_col).isNotNull(), col(array_col)).otherwise(array(lit(None)).cast(schema)))
nested_df = nested_df.withColumn("tmp", arrays_zip(*array_cols)).withColumn("tmp", explode("tmp")).select([col("tmp."+c).alias(c) for c in array_cols] + flat_cols)
print("explode_dfs_count :", nested_df.count())
return nested_df
new_df = flatten_df(myDf)
while True:
array_cols = [c[0] for c in new_df.dtypes if c[1][:5] == 'array']
if len(array_cols):
new_df = flatten_df(explode_df(new_df))
else:
break
new_df.printSchema()
使用 arrays_zip
和 explode
来加快速度并解决 null
问题。