Spark UDAF: how do I get a value from the input by column field name in a UDAF (user-defined aggregate function)?

Asked: 2018-01-15 04:09:06

Tags: scala apache-spark apache-spark-sql aggregate user-defined-aggregate

I am trying to use a Spark UDAF to summarize two existing columns into a new column. Most tutorials on Spark UDAFs use an index to get the value from each column of the input row, like this:

input.getAs[String](1)

which is used in my update method (override def update(buffer: MutableAggregationBuffer, input: Row): Unit). It works in my case, too. However, I would like to use the column's field name to get the value, like this:

input.getAs[String](ColumnNames.BehaviorType)

where ColumnNames.BehaviorType is a String defined in an object:

/**
  * Column names in the original dataset
  */
object ColumnNames {
  val JobSeekerID = "JobSeekerID"
  val JobID = "JobID"
  val Date = "Date"
  val BehaviorType = "BehaviorType"
}

This time it does not work. I get the following exception:

java.lang.IllegalArgumentException: Field "BehaviorType" does not exist.
  at org.apache.spark.sql.types.StructType$$anonfun$fieldIndex$1.apply(StructType.scala:292)
  ...
  at org.apache.spark.sql.Row$class.getAs(Row.scala:333)
  at org.apache.spark.sql.catalyst.expressions.GenericRow.getAs(rows.scala:165)
  at com.recsys.UserBehaviorRecordsUDAF.update(UserBehaviorRecordsUDAF.scala:44)

Some relevant code snippets:

Here is part of my UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
      StructField("BehaviorType", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
    println(input.schema.treeString)
    println()
    println(input.mkString(","))
    println()
    println(this.inputSchema.treeString)

    input.getAs[String](ColumnNames.BehaviorType) match { // works with index 1, fails with the name -- why?
      case BehaviourTypes.viewed_job =>
        buffer(0) = buffer.getAs[Seq[Int]](0) :+ input.getAs[Int](0) // JobID
      case BehaviourTypes.bookmarked_job =>
        buffer(1) = buffer.getAs[Seq[Int]](1) :+ input.getAs[Int](0) // JobID
      case BehaviourTypes.applied_job =>
        buffer(2) = buffer.getAs[Seq[Int]](2) :+ input.getAs[Int](0) // JobID
    }
  }

  // bufferSchema, dataType, deterministic, initialize, merge and evaluate omitted here
}
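One workaround (an editorial sketch, not from the original post) is to stop looking the name up on the incoming row and instead resolve each position once from the schema the UDAF itself declares; StructType.fieldIndex is a plain lookup on the declared StructType and needs no running Spark session:

```scala
import org.apache.spark.sql.types._

// Same schema the UDAF declares in inputSchema.
val inputSchema = StructType(
  StructField("JobID", IntegerType) ::
    StructField("BehaviorType", StringType) :: Nil)

// Resolve each column's position once, by name.
val jobIdIdx        = inputSchema.fieldIndex("JobID")        // 0
val behaviorTypeIdx = inputSchema.fieldIndex("BehaviorType") // 1

// Inside update, access stays positional but the code still reads by name:
//   input.getAs[String](behaviorTypeIdx)
```

This keeps the maintainability of named columns while matching how Spark actually hands the row to update.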

Here is the part of the code that calls the UDAF:

val ubrUDAF = new UserBehaviorRecordsUDAF

val userProfileDF = userBehaviorDS
  .groupBy(ColumnNames.JobSeekerID)
  .agg(
    ubrUDAF(
      userBehaviorDS.col(ColumnNames.JobID),
      userBehaviorDS.col(ColumnNames.BehaviorType)
    ).as("profile str"))

It seems the field names from the input row's schema are not passed into the UDAF:

XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
root
 |-- input0: integer (nullable = true)
 |-- input1: string (nullable = true)


30917,viewed_job

root
 |-- JobID: integer (nullable = true)
 |-- BehaviorType: string (nullable = true)

What is wrong with my code?

1 Answer:

Answer 0 (score: 0):

I also wanted to use the field names from inputSchema in my update method, to keep the code maintainable.

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

class MyUDAF extends UserDefinedAggregateFunction {
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    val inputWSchema = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)
    // ... from here on, inputWSchema.getAs[String]("BehaviorType") resolves by name
  }
}
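To make the fix concrete, here is a minimal standalone sketch (it assumes only spark-sql on the classpath; no SparkSession is needed) showing that re-wrapping the positional row with the declared schema restores name-based access:

```scala
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

val schema = StructType(
  StructField("JobID", IntegerType) ::
    StructField("BehaviorType", StringType) :: Nil)

// The row that reaches update() only carries generated names (input0, input1);
// wrapping its values in GenericRowWithSchema re-attaches the declared names.
val row = new GenericRowWithSchema(Array[Any](30917, "viewed_job"), schema)

row.getAs[String]("BehaviorType") // "viewed_job" -- the lookup that previously threw
row.getAs[Int]("JobID")           // 30917
```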

Ended up switching to an Aggregator, which ran in half the time.
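The Aggregator the answer switched to is not shown; a minimal sketch of that alternative (the case class and the bookmarked_job/applied_job string values are assumptions based on the identifiers in the question) could look like:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Typed input and buffer: columns become case-class fields, so access is by name.
case class Behavior(JobSeekerID: Int, JobID: Int, Date: String, BehaviorType: String)
case class Profile(viewed: Seq[Int], bookmarked: Seq[Int], applied: Seq[Int])

object UserBehaviorAgg extends Aggregator[Behavior, Profile, Profile] {
  def zero: Profile = Profile(Nil, Nil, Nil)

  def reduce(b: Profile, a: Behavior): Profile = a.BehaviorType match {
    case "viewed_job"     => b.copy(viewed = b.viewed :+ a.JobID)
    case "bookmarked_job" => b.copy(bookmarked = b.bookmarked :+ a.JobID)
    case "applied_job"    => b.copy(applied = b.applied :+ a.JobID)
    case _                => b
  }

  def merge(b1: Profile, b2: Profile): Profile =
    Profile(b1.viewed ++ b2.viewed,
            b1.bookmarked ++ b2.bookmarked,
            b1.applied ++ b2.applied)

  def finish(r: Profile): Profile = r
  def bufferEncoder: Encoder[Profile] = Encoders.product[Profile]
  def outputEncoder: Encoder[Profile] = Encoders.product[Profile]
}
```

Grouping then becomes `userBehaviorDS.as[Behavior].groupByKey(_.JobSeekerID).agg(UserBehaviorAgg.toColumn)`, with no index bookkeeping at all.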