Spark SQL custom function in a when clause

Date: 2018-05-11 19:24:15

Tags: scala apache-spark apache-spark-sql

I have a DataFrame like this:

id val1 val2
------------
 1  v11  v12
 2  v21  v22
 3  v31  v32
 4  v41  v42
 5  v51  v52
 6  v61  v62

Each row represents a person who may belong to one or more groups. I have a function that takes the values of a row and determines whether that person meets the criteria for a particular group:

def isInGroup(group: Int)(id: String, v1: String, v2: String): Boolean
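
For illustration only, a hypothetical body for such a curried function is sketched below; the real group criteria are not given in the question, so the rules here are placeholders:

// Placeholder logic -- the actual criteria are not shown in the question.
def isInGroup(group: Int)(id: String, v1: String, v2: String): Boolean =
  group match {
    case 1 => v1.nonEmpty && v2.nonEmpty       // hypothetical rule for Group1
    case _ => v2.endsWith(group.toString)      // hypothetical rule for the other groups
  }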

I am trying to produce an output DataFrame like this:

Group1 Group2 Group3 Group4
---------------------------
     3      0      6      1

Here is my code so far, but it does not work. Unfortunately, the when clause only accepts Column-typed arguments, so my function does not fit there, and a user-defined function did not work for me either. If possible, I would really like to stick with select / struct /.

val summaryDF = dataDF
    .select(struct(
        sum(when(isInGroup(1)($"id", $"val1", $"val2"), value = 1)).as("Group1"),
        sum(when(isInGroup(2)($"id", $"val1", $"val2"), value = 1)).as("Group2"),
        sum(when(isInGroup(3)($"id", $"val1", $"val2"), value = 1)).as("Group3"),
        sum(when(isInGroup(4)($"id", $"val1", $"val2"), value = 1)).as("Group4")
    ))

1 Answer:

Answer 0 (score: 0)

As I showed in my previous answer, you need a udf:

import org.apache.spark.sql.functions.udf

// Wrap the partially applied isInGroup(group) _ as a UDF that takes three Column arguments
def isInGroupUDF(group: Int) = udf(isInGroup(group) _)

sum(when(
  isInGroupUDF(1)($"id", $"val1", $"val2"), 1
)).as("Group1")

If you want to avoid listing the columns every time, you can try default arguments:

import org.apache.spark.sql.Column
import spark.implicits._  // for the $"..." column syntax (assumes a SparkSession named spark)

def isInGroupUDF(group: Int, id: Column = $"id",
                 v1: Column = $"val1", v2: Column = $"val2") = {
  val f = udf(isInGroup(group) _)
  f(id, v1, v2)
}

sum(when(
  isInGroupUDF(1), 1
)).as("Group1")