Given a df as follows:
val df = spark.createDataFrame(Seq(
(1, 2, 3),
(3, 2, 1)
)).toDF("One", "Two", "Three")
I want to write a UDF that takes the three columns as input and returns a new column based on the highest input value, something like this:
import org.apache.spark.sql.functions.udf
def udfScoreToCategory=udf((One: Int, Two: Int, Three: Int): Int => {
  cols match {
    case cols if One > Two && One > Three => 1
    case cols if Two > One && Two > Three => 2
    case _ => 0
  }}
It would also be interesting to see how something similar would work with a vector type as input:
import org.apache.spark.ml.linalg.Vector
def udfVectorToCategory=udf((cols:org.apache.spark.ml.linalg.Vector): Int => {
  cols match {
    case cols if cols(0) > cols(1) && cols(0) > cols(2) => 1,
    case cols if cols(1) > cols(0) && cols(1) > cols(2) => 2
    case _ => 0
  }})
Answer 0 (score: 1)
I was able to find the maximum element of a vector with:
val vectorToCluster = udf{ (x: Vector) => x.argmax }
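For context, a minimal sketch of applying this, assuming the three columns are first packed into a vector column (the features column name and the VectorAssembler step are illustrative, not from the original post):

import org.apache.spark.ml.feature.VectorAssembler
import spark.implicits._

// Assemble the three integer columns into a single vector column "features"
val assembler = new VectorAssembler()
  .setInputCols(Array("One", "Two", "Three"))
  .setOutputCol("features")
val withVec = assembler.transform(df)

// argmax gives the 0-based index of the largest element in each vector
withVec.withColumn("cluster", vectorToCluster($"features")).show()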
However, I am still confused about how to pattern match over multiple column values.
Answer 1 (score: 1)
A few problems:

- `cols` is not in scope.
- `(...): T => ...` is not valid syntax for an anonymous function.
- Use `val` instead of `def` here.

One way to define this method:
val udfScoreToCategory = udf((one: Int, two: Int, three: Int) =>
  (one, two, three) match {
    case (one, two, three) if one > two && one > three => 1
    case (one, two, three) if two > one && two > three => 2
    case _ => 0
  })
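For reference, a quick usage sketch against the example df (assuming spark.implicits._ is in scope for the $ column syntax):

import spark.implicits._

df.withColumn("category", udfScoreToCategory($"One", $"Two", $"Three")).show()
// (1, 2, 3): Three is the highest and no case covers it, so the result is 0
// (3, 2, 1): One is the highest, so the result is 1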
and
val udfVectorToCategory = udf[Int, org.apache.spark.ml.linalg.Vector] {
  _.toArray match {
    case Array(one, two, three) if one > two && one > three => 1
    case Array(one, two, three) if two > one && two > three => 2
    case _ => 0
  }
}
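Reusing the illustrative features vector column from the sketch above, it can be invoked the same way:

withVec.withColumn("category", udfVectorToCategory($"features")).show()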
In general, for the first case you should use `when` instead:
import org.apache.spark.sql.functions.when
when ($"one" > $"two" && $"one" > $"three", 1)
.when ($"two" > $"one" && $"two" > $"three", 2)
.otherwise(0)
where `one`, `two`, and `three` are the column names.
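Putting it together against the example df (whose columns are actually named One, Two, and Three), a minimal sketch:

import org.apache.spark.sql.functions.when
import spark.implicits._

df.withColumn("category",
  when($"One" > $"Two" && $"One" > $"Three", 1)
    .when($"Two" > $"One" && $"Two" > $"Three", 2)
    .otherwise(0)
).show()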