我有一个像这样的DataFrame:
id val1 val2
------------
1 v11 v12
2 v21 v22
3 v31 v32
4 v41 v42
5 v51 v52
6 v61 v62
每一行代表一个可能属于一个或多个组的人。我有一个函数,它接受每一行的值,并确定该人是否符合特定组的标准:
def isInGroup: Boolean = f(group: Int)(id: String, v1: String, v2: String)
我试图像这样输出一个DataFrame:
Group1 Group2 Group3 Group4
---------------------------
3 0 6 1
到目前为止,这是我的代码,但不起作用。不幸的是,when子句只接受Column
类型的参数,而我的函数不起作用。用户定义的功能也不起作用。如果可能的话,我真的很想坚持使用select / struct /。
val summaryDF = dataDF
.select(struct(
sum(when(isInGroup(1)($"id", $"val1", $"val2"), value = 1)).as("Group1")),
sum(when(isInGroup(2)($"id", $"val1", $"val2"), value = 1)).as("Group2")),
sum(when(isInGroup(3)($"id", $"val1", $"val2"), value = 1)).as("Group3")),
sum(when(isInGroup(4)($"id", $"val1", $"val2"), value = 1)).as("Group4"))
))
答案 0 :(得分:0)
正如我在my previous answer中所示,您需要一个udf
:
import org.apache.spark.sql.functions.udf
def isInGroupUDF(group: Int) = udf(isInGroup(group) _)
sum(when(
isInGroupUDF(1)($"id", $"val1", $"val2"), 1
)).as("Group1")
如果您想避免列出列,可以尝试使用默认参数:
def isInGroupUDF(group: Int, id: Column = $"id",
v1: Column = $"val1", v2: Column = $"val2") = {
val f = udf(isInGroup(group) _)
f(id, v1, v2)
}
sum(when(
isInGroupUDF(1), 1
)).as("Group1")