I am trying to use a Spark UDAF to aggregate two existing columns into a new column. Most tutorials on Spark UDAFs use an index to get the value from each column of the input row, like this:

input.getAs[String](1)

used inside my update method (override def update(buffer: MutableAggregationBuffer, input: Row): Unit). It works in my case as well. However, I want to use the column's field name to get the value, like this:

input.getAs[String](ColumnNames.BehaviorType)

where ColumnNames.BehaviorType is a String defined in an object:
/**
 * Column names in the original dataset
 */
object ColumnNames {
  val JobSeekerID = "JobSeekerID"
  val JobID = "JobID"
  val Date = "Date"
  val BehaviorType = "BehaviorType"
}
This time it does not work. I get the following exception:

java.lang.IllegalArgumentException: Field "BehaviorType" does not exist.
  at org.apache.spark.sql.types.StructType$$anonfun$fieldIndex$1.apply(StructType.scala:292)
  ...
  at org.apache.spark.sql.Row$class.getAs(Row.scala:333)
  at org.apache.spark.sql.catalyst.expressions.GenericRow.getAs(rows.scala:165)
  at com.recsys.UserBehaviorRecordsUDAF.update(UserBehaviorRecordsUDAF.scala:44)
Some relevant code snippets:

Here is part of my UDAF:
class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
    StructField("BehaviorType", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
    println(input.schema.treeString)
    println
    println(input.mkString(","))
    println
    println(this.inputSchema.treeString)
    // println
    // println(bufferSchema.treeString)

    input.getAs[String](ColumnNames.BehaviorType) match { // works with index 1, fails with the name //TODO WHY??
      case BehaviourTypes.viewed_job =>
        buffer(0) =
          buffer.getAs[Seq[Int]](0) :+
          input.getAs[Int](0) // ColumnNames.JobID
      case BehaviourTypes.bookmarked_job =>
        buffer(1) =
          buffer.getAs[Seq[Int]](1) :+
          input.getAs[Int](0) // ColumnNames.JobID
      case BehaviourTypes.applied_job =>
        buffer(2) =
          buffer.getAs[Seq[Int]](2) :+
          input.getAs[Int](0) // ColumnNames.JobID
    }
  }
Here is the part of the code that calls the UDAF:
val ubrUDAF = new UserBehaviorRecordsUDAF
val userProfileDF = userBehaviorDS
  .groupBy(ColumnNames.JobSeekerID)
  .agg(
    ubrUDAF(
      userBehaviorDS.col(ColumnNames.JobID),
      userBehaviorDS.col(ColumnNames.BehaviorType)
    ).as("profile str"))
The field names from the input rows' schema do not seem to be passed into the UDAF:
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
root
|-- input0: integer (nullable = true)
|-- input1: string (nullable = true)
30917,viewed_job
root
|-- JobID: integer (nullable = true)
|-- BehaviorType: string (nullable = true)
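The schema dump above shows what goes wrong: the row handed to update carries generated positional field names (input0, input1), so a lookup by the original column name has nothing to match and fieldIndex throws. A minimal plain-Scala sketch of this mechanism (SimpleRow is a hypothetical stand-in for Spark's Row, not Spark API):

```scala
// Hypothetical stand-in for Spark's Row: getAs by name only works if the
// row's attached field names contain that name, mirroring Row.fieldIndex.
case class SimpleRow(values: Seq[Any], fieldNames: Seq[String]) {
  def getAs[T](i: Int): T = values(i).asInstanceOf[T]
  def getAs[T](name: String): T = {
    val i = fieldNames.indexOf(name)
    if (i < 0) throw new IllegalArgumentException(s"""Field "$name" does not exist.""")
    getAs[T](i)
  }
}

object Demo {
  def main(args: Array[String]): Unit = {
    // The row the UDAF actually receives: values intact, names rewritten
    val input = SimpleRow(Seq(30917, "viewed_job"), Seq("input0", "input1"))
    println(input.getAs[String](1)) // positional access works: viewed_job
    try input.getAs[String]("BehaviorType") // name lookup fails
    catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}
```

Positional access succeeds because the values themselves arrive in the declared order; only the names are rewritten.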
What is wrong with my code?
Answer 0 (score: 0)
I also wanted to use the field names from inputSchema in the update method, to keep the code maintainable.
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

class MyUDAF extends UserDefinedAggregateFunction {
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Re-attach the declared schema to the row so lookups by field name work
    val inputWSchema = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)
    // ... from here on, inputWSchema.getAs[String](fieldName) succeeds
  }
}
I eventually switched to an Aggregator, which ran in half the time.
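Another way to keep the names without wrapping every row is to resolve each name to its position once, against the same field order the UDAF declares in inputSchema, and then read by index inside update. A minimal plain-Scala sketch of that idea (the InputIndex object and its members are hypothetical names, not from the original post):

```scala
// Resolve column names to positions once, against the declared input order.
// In the real UDAF this order would come from inputSchema; here it is inlined.
object InputIndex {
  private val fields = Seq("JobID", "BehaviorType")

  private def indexOf(name: String): Int = {
    val i = fields.indexOf(name)
    require(i >= 0, s"""Field "$name" does not exist.""")
    i
  }

  val JobID: Int = indexOf("JobID")
  val BehaviorType: Int = indexOf("BehaviorType")
}

// update would then use input.getAs[String](InputIndex.BehaviorType),
// which is a plain positional read after the one-time resolution.
```

In the actual UDAF the same effect can be had from Spark's StructType.fieldIndex on the declared inputSchema, computed once rather than per row.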