How do I perform arithmetic on DataFrame groups through aggregation in Spark?

Asked: 2018-11-30 18:32:40

Tags: scala apache-spark

I have a DataFrame that looks like this:

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")


+---+---+---+
| F1| F2| F3|
+---+---+---+
|  x|  y|  1|
|  x|  z|  2|
|  x|  a|  4|
|  x|  a|  5|
|  t|  y|  1|
|  t| y2|  6|
|  t| y3|  3|
|  t| y4|  5|
+---+---+---+

How can I group by column "F1" and multiply the values of "F3" within each group?

For a sum I can do the following, but I am not sure which function to use for multiplication.

df.groupBy("F1").agg(sum("F3")).show

+---+-------+
| F1|sum(F3)|
+---+-------+
|  x|     12|
|  t|     15|
+---+-------+

2 Answers:

Answer 0 (score: 3)

val df = Seq(("x", "y", 1), ("x", "z", 2), ("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")
import org.apache.spark.sql.Row
// Group the rows by F1, then multiply the F3 values pairwise within each group.
val x = df.select($"F1", $"F3").groupByKey { case r => r.getString(0) }.reduceGroups { (r, r2) => Row(r.getString(0), r.getInt(1) * r2.getInt(1)) }

x.show()

+-----+------------------------------------------+
|value|ReduceAggregator(org.apache.spark.sql.Row)|
+-----+------------------------------------------+
|    x|                                   [x, 40]|
|    t|                                   [t, 90]|
+-----+------------------------------------------+
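The result above keeps each product inside a nested Row column. As a hedged variation on the same groupByKey/reduceGroups idea, the sketch below works on typed tuples and ends with two flat columns. It assumes a spark-shell session (where `spark` already exists); the names `products` and the output column `product` are illustrative choices, not part of the original answer.

```scala
// Assumes spark-shell, where `spark` is a ready SparkSession.
import spark.implicits._

val df = Seq(("x", "y", 1), ("x", "z", 2), ("x", "a", 4), ("x", "a", 5),
  ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

// Trailing dots keep the chain valid when typed line by line in the REPL.
val products = df.select($"F1", $"F3").as[(String, Int)].
  groupByKey(_._1).              // key each row by F1
  mapValues(_._2.toLong).        // keep only F3, widened to Long
  reduceGroups((a, b) => a * b). // multiply the values within each group
  toDF("F1", "product")          // flat columns instead of a nested Row

products.show()
```

For the sample data this should give 40 for group x and 90 for group t; the row order is not guaranteed.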

Answer 1 (score: 1)

Define a custom aggregate function as follows:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

class Product extends UserDefinedAggregateFunction {
  // Input fields of the aggregate function.
  override def inputSchema: StructType =
    StructType(StructField("value", LongType) :: Nil)

  // Internal fields kept while computing the aggregate.
  override def bufferSchema: StructType =
    StructType(StructField("product", LongType) :: Nil)

  // Output type of the aggregate function.
  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  // Initial buffer value: 1 is the multiplicative identity.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1L
  }

  // Fold one input row into the buffer. Skipping nulls is an added safeguard.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) * input.getLong(0)
    }
  }

  // Merge two partial buffers that share the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) * buffer2.getLong(0)
  }

  // Produce the final value from the final buffer.
  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}
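Spark may apply update to the rows of each partition separately and then merge the partial buffers, which is why the buffer starts at 1 and why merge also multiplies. The plain-Scala sketch below imitates that flow without Spark. The two-partition split and the standalone helpers (`initial`, `update`, `merge`, `partitions`) are illustrative assumptions, not Spark API.

```scala
// Plain Scala, no Spark: imitates initialize / update / merge for group "x".
val initial = 1L // multiplicative identity, like initialize()

def update(buffer: Long, input: Long): Long = buffer * input
def merge(b1: Long, b2: Long): Long = b1 * b2

// Hypothetical split of the F3 values of group "x" across two partitions.
val partitions = Seq(Seq(1L, 2L), Seq(4L, 5L))

// One partial buffer per partition, then a merge of the partials.
val partials = partitions.map(_.foldLeft(initial)(update))
val result = partials.foldLeft(initial)(merge)

println(result) // 40
```

Starting the buffer at 0 instead would force every product to 0, and a merge that added the partials would give 22 here rather than 40.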

Then use the Product class in the aggregation as follows:

val product = new Product

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

df.groupBy("F1").agg(product(col("F3"))).show

Here is the output:

+---+-----------+
| F1|product(F3)|
+---+-----------+
|  x|         40|
|  t|         90|
+---+-----------+
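On newer Spark versions a custom class may not be needed. UserDefinedAggregateFunction has been deprecated since Spark 3.0, and, as far as I know, Spark 3.2 added a built-in `product` aggregate to `org.apache.spark.sql.functions` that returns a double. The sketch below assumes spark-shell on Spark 3.2 or later; it also shows the older exp(sum(log(...))) workaround, which only holds for strictly positive values and needs rounding because of floating-point error. The names `viaBuiltin` and `viaLogs` are illustrative.

```scala
// Assumes spark-shell on Spark 3.2+, where `spark` is a ready SparkSession.
import spark.implicits._
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.{col, exp, log, round, sum}

val df = Seq(("x", "y", 1), ("x", "z", 2), ("x", "a", 4), ("x", "a", 5),
  ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

// Built-in aggregate (assumed available from Spark 3.2); cast the double back to long.
// Qualified as functions.product to avoid clashing with a local `product` value.
val viaBuiltin = df.groupBy("F1").
  agg(functions.product(col("F3")).cast("long").as("product"))

// Workaround for older versions: valid only when every F3 value is > 0.
val viaLogs = df.groupBy("F1").
  agg(round(exp(sum(log(col("F3"))))).cast("long").as("product"))

viaBuiltin.show()
viaLogs.show()
```

Both variants should agree with the UDAF result for this data (x gives 40, t gives 90). The log-based form breaks down for zero or negative inputs, so treat it as a fallback only.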