spark多个列和集合元素的UDF模式匹配

时间:2017-08-16 12:26:22

标签: scala apache-spark user-defined-functions

给出df如下:

enter image description here

val df = spark.createDataFrame(Seq(
(1, 2, 3),
(3, 2, 1)
)).toDF("One", "Two", "Three")

with schema: enter image description here

我想写一个udfThree columns作为inout;并根据类似如下的最高输入值返回新列:

import org.apache.spark.sql.functions.udf


def udfScoreToCategory=udf((One: Int, Two: Int, Three: Int): Int => {
    cols match {
    case cols if One > Two && One > Three => 1
    case cols if Two > One && Two > Three => 2
    case _ => 0
}}

看看如何与vector type作为输入类似,将会很有趣:

import org.apache.spark.ml.linalg.Vector

def udfVectorToCategory=udf((cols:org.apache.spark.ml.linalg.Vector): Int => {
    cols match {
    case cols if cols(0) > cols(1) && cols(0) > cols(2) => 1,
    case cols if cols(1) > cols(0) && cols(1) > cols(2) => 2
    case _ => 0
}})

2 个答案:

答案 0 :(得分:1)

我能够通过以下方式找到矢量的最大元素:

  val vectorToCluster = udf{ (x: Vector) => x.argmax }

但是,我仍然对如何在多列值上进行模式匹配感到困惑。

答案 1 :(得分:1)

一些问题:

    第一个示例中的
  • cols不在范围内。
  • (...): T => ...不是匿名函数的有效语法。
  • 最好在val使用def

定义此方法的一种方法:

val udfScoreToCategory = udf[Int, (Int, Int, Int)]{
  case (one, two, three) if one > two && one > three => 1
  case (one, two, three) if two > one && two > three => 2
  case _ => 0
}

val udfVectorToCategory = udf[Int, org.apache.spark.ml.linalg.Vector]{
  _.toArray match {
    case Array(one, two, three) if one > two && one > three => 1
    case Array(one, two, three) if two > one && two > three => 2
    case _ => 0
}}

一般来说,对于第一种情况,你应该使用``when`

import org.apache.spark.sql.functions.when

when ($"one" > $"two" && $"one" > $"three", 1)
  .when ($"two" > $"one" && $"two" > $"three", 2)
  .otherwise(0)

其中onetwothree是列名。