DataFrame:在groupBy上应用自定义函数

时间:2018-05-30 13:13:57

标签: scala apache-spark dataframe

我有一个带有以下结构的Dataframe DF:DF(标记,值)和与整个Dataframe相关的分数(Double值)。

我有一个函数,它接受参数一个Dataframe和一个分数,并返回一个新值。我们可以像这样代表这个函数:

def computeNewValue(DF:DataFrame, score:Double):Double={
  // Computing that depends of 'value' and all data in 'DF'
}

当我想要获得我的“新价值”时,一切正常。对于整个Dataframe。

但是现在,我想将此功能应用于Dataframe,而不是Dataframe中的组。我希望得到一个新的价值观。对于每个标签。让我们举个例如我的得分'只是值列的最大值。我想要一些我可以这样使用的东西:

DF.groupBy($"tag")
  .agg(max($"value").as("score"))
  .withColumn("newValue",computeNewValue([MyTagGroup],$"score") // To change

DF.groupBy($"tag")
      .agg(
         max($"value").as("score"), 
         computeNewValue([MyTagGroup],$"score").as("newValue") // To change
       )

有没有一种很好的方法来实现这样的目标?

编辑:

我尝试使用UDAF解决我的问题,但似乎不是这样做的好方法。以下是我正在尝试做的更多细节:

我的函数computeNewValue的完整代码:

def computeNewValue(DF:DataFrame,score:Double):Double={
     DF.select($"*",(percent_rank over Window.orderBy("value")).alias("percentile")) 
                  .filter($"percentile">=0.05 && $"percentile"<=0.95)                  
                  .agg(min("value").as("min"),max("value").as("max")) 
                  .withColumn("newValue",(lit(score)/($"max"-$"min"))*100)
                  .select("newValue")
                  .collect()(0).getDouble(0)
}

这是我的UDAF:

class computeNewValue_udaf extends UserDefinedAggregateFunction {

  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", DoubleType) :: StructField("score", DoubleType) :: Nil)

  override def bufferSchema: StructType = StructType(
    StructField("min", DoubleType) ::
    StructField("max", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
     buffer(0)=0
     buffer(1)=0
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

        if(buffer(0).asInstanceOf[Int]>input(0).asInstanceOf[Int]){
          buffer(0)=input
        }

        if(buffer(1).asInstanceOf[Int]<input(0).asInstanceOf[Int]){
          buffer(1)=input
        }
        // Can't specify the 'percentile' part
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // ?
  }

  override def evaluate(buffer: Row): Any = {
    score*100 / (buffer(1).asInstanceOf[Int]-buffer(0).asInstanceOf[Int]) 
    // Do not work because can't access 'score' value
  }
}

0 个答案:

没有答案