以下Spark代码正确演示了我想要做的事情,并使用一个小型演示数据集生成正确的输出。
当我在大量生产数据上运行相同的一般类型的代码时,我遇到了运行时问题。 Spark作业在我的集群上运行大约12个小时并且失败了。
只要看一下下面的代码,爆炸每一行似乎都是低效的,只是将它合并回来。在给定的测试数据集中,第四行包含array_value_1中的三个值和array_value_2中的三个值,它们将爆炸为3 * 3或9个爆炸行。
那么,在一个更大的数据集中,一行有五个这样的数组列,每列有十个值,会爆炸成10 ^ 5个爆炸行?
查看提供的Spark函数,没有开箱即用的功能可以实现我想要的功能。我可以提供用户定义的功能。这有什么速度上的缺点吗?
val sparkSession = SparkSession.builder.
master("local")
.appName("merge list test")
.getOrCreate()
val schema = StructType(
StructField("category", IntegerType) ::
StructField("array_value_1", ArrayType(StringType)) ::
StructField("array_value_2", ArrayType(StringType)) ::
Nil)
val rows = List(
Row(1, List("a", "b"), List("u", "v")),
Row(1, List("b", "c"), List("v", "w")),
Row(2, List("c", "d"), List("w")),
Row(2, List("c", "d", "e"), List("x", "y", "z"))
)
val df = sparkSession.createDataFrame(rows.asJava, schema)
val dfExploded = df.
withColumn("scalar_1", explode(col("array_value_1"))).
withColumn("scalar_2", explode(col("array_value_2")))
// This will output 19. 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")
val dfOutput = dfExploded.groupBy("category").agg(
collect_set("scalar_1").alias("combined_values_2"),
collect_set("scalar_2").alias("combined_values_2"))
dfOutput.show()
答案 0 :(得分:18)
explode
可能效率低下,但从根本上说,您尝试实施的操作非常昂贵。实际上它只是另一个groupByKey
而你在这里做的并不多,可以让它变得更好。既然你使用Spark> 2.0你可以直接collect_list
展平:
import org.apache.spark.sql.functions.{collect_list, udf}
val flatten_distinct = udf(
(xs: Seq[Seq[String]]) => xs.flatten.distinct)
df
.groupBy("category")
.agg(
flatten_distinct(collect_list("array_value_1")),
flatten_distinct(collect_list("array_value_2"))
)
在Spark> = 2.4中,您可以将udf替换为内置函数的组合:
import org.apache.spark.sql.functions.{array_distinct, flatten}
val flatten_distinct = (array_distinct _) compose (flatten _)
也可以使用custom Aggregator
,但我怀疑其中任何一项都会产生巨大的影响。
如果集合相对较大并且您期望重复数量很多,则可以尝试将aggregateByKey
用于可变集:
import scala.collection.mutable.{Set => MSet}
val rdd = df
.select($"category", struct($"array_value_1", $"array_value_2"))
.as[(Int, (Seq[String], Seq[String]))]
.rdd
val agg = rdd
.aggregateByKey((MSet[String](), MSet[String]()))(
{case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++ ys)},
{case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++ accY2)}
)
.mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }
.toDF