How to iterate over grouped data in Spark?

Date: 2017-05-23 09:43:18

Tags: java sql scala apache-spark

I have a dataset like this:

uid    group_a    group_b
1      3          unknown
1      unknown    4
2      unknown    3
2      2          unknown

I want to get this result:

uid    group_a    group_b
1      3          4
2      2          3

I tried grouping the data by "uid", then iterating over each group and picking the non-unknown value as the final value, but I don't know how to do it.

3 answers:

Answer 0: (score: 0)

After formatting the dataset as a PairRDD, you can use the reduceByKey operation to find the single known value. The following example assumes there is only one known value per uid per column; otherwise the first known value encountered is returned.

val input = List(
    ("1", "3", "unknown"),
    ("1", "unknown", "4"),
    ("2", "unknown", "3"),
    ("2", "2", "unknown")
)

// Key by uid, keeping (group_a, group_b) as the value
val pairRdd = sc.parallelize(input).map(l => (l._1, (l._2, l._3)))

// For each uid, keep the first non-"unknown" value seen in each position
val result = pairRdd.reduceByKey { (a, b) =>
    val groupA = if (a._1 != "unknown") a._1 else b._1
    val groupB = if (a._2 != "unknown") a._2 else b._2
    (groupA, groupB)
}

The result will be a pairRdd that looks like this:

(uid, (group_a, group_b))
(1,(3,4))                                                                       
(2,(2,3))

You can get back to a plain row format with a simple map operation.
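A minimal sketch of that last step (reusing the result pair RDD from above):

// Flatten (uid, (group_a, group_b)) back into plain (uid, group_a, group_b) tuples
val flat = result.map { case (uid, (groupA, groupB)) => (uid, groupA, groupB) }
flat.collect().foreach(println)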

Answer 1: (score: 0)

I suggest you define a User Defined Aggregate Function (UDAF).

Using inbuilt functions is a good approach, but they are hard to customize. With a UDAF you have full control and can customize it as you need.

For your problem, the following could be a solution. You can edit it as you need.

The first task is to define the UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StringType, StructType}

class PingJiang extends UserDefinedAggregateFunction {

  // Input: the two group columns; buffer: the last known value seen for each
  def inputSchema = new StructType().add("group_a", StringType).add("group_b", StringType)
  def bufferSchema = new StructType().add("buff0", StringType).add("buff1", StringType)
  def dataType = StringType
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, "")
    buffer.update(1, "")
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0)) {
      val groupa = input.getString(0)
      val groupb = input.getString(1)

      // Keep only values that are not "unknown"
      if (!groupa.equalsIgnoreCase("unknown")) {
        buffer.update(0, groupa)
      }
      if (!groupb.equalsIgnoreCase("unknown")) {
        buffer.update(1, groupb)
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    // Take the value found by the other partial buffer whenever it has one
    if (buffer2.getString(0).nonEmpty) buffer1.update(0, buffer2.getString(0))
    if (buffer2.getString(1).nonEmpty) buffer1.update(1, buffer2.getString(1))
  }

  def evaluate(buffer: Row): String = {
    // Return both values as "group_a,group_b"; the caller splits on ","
    buffer.getString(0) + "," + buffer.getString(1)
  }
}

Then you call it from your main class and do some manipulation to get the result you need:

import org.apache.spark.sql.functions.split
import spark.implicits._  // assumes a SparkSession named `spark` is in scope

val data = Seq(
  (1, "3", "unknown"),
  (1, "unknown", "4"),
  (2, "unknown", "3"),
  (2, "2", "unknown"))
  .toDF("uid", "group_a", "group_b")

val udaf = new PingJiang()

val result = data.groupBy("uid").agg(udaf($"group_a", $"group_b").as("ping"))
  .withColumn("group_a", split($"ping", ",")(0))
  .withColumn("group_b", split($"ping", ",")(1))
  .drop("ping")
result.show(false)
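With the sample data above, this should print something like the following (row order may differ):

+---+-------+-------+
|uid|group_a|group_b|
+---+-------+-------+
|1  |3      |4      |
|2  |2      |3      |
+---+-------+-------+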

Visit databricks and augmentiq to understand UDAFs better.

Note: the solution above gives you the latest known value for each group, if one exists (you can always edit it as you need).

Answer 2: (score: 0)

You could replace all "unknown" values with null, then use the first() function inside a map (as shown here) to get the first non-null value in each column for each group:

import org.apache.spark.sql.functions.{col,first,when}
// We are only gonna apply our function to the last 2 columns
val cols = df.columns.drop(1)
// Create expression
val exprs = cols.map(first(_,true))
// Putting it all together
df.select(df.columns
          .map(c => when(col(c) === "unknown", null)
          .otherwise(col(c)).as(c)): _*)
  .groupBy("uid")
  .agg(exprs.head, exprs.tail: _*).show()
+---+--------------------+--------------------+
|uid|first(group_a, true)|first(group_b, true)|
+---+--------------------+--------------------+
|  1|                   3|                   4|
|  2|                   2|                   3|
+---+--------------------+--------------------+

Data:

val df = sc.parallelize(Array(("1","3","unknown"),("1","unknown","4"),
                              ("2","unknown","3"),("2","2","unknown"))).toDF("uid","group_a","group_b")