使用Apache Spark flattern每个组的2个第一行使用Java

时间:2016-11-24 17:56:04

标签: java mysql apache-spark hive

给出以下输入表:

+----+------------+----------+
| id | shop       | purchases|
+----+------------+----------+
|  1 | 01         |       20 |
|  1 | 02         |       31 |
|  2 | 03         |        5 |
|  1 | 03         |        3 |
+----+------------+----------+

我想按ID分组并根据购买情况获得前2家顶级商店:

+----+-------+------+
| id | top_1 | top_2|
+----+-------+------+
|  1 | 02    |   01 |
|  2 | 03    |      |
+----+-------+------+

我使用的是Apache Spark 2.0.1,第一个表是数据集中其他查询和连接的结果。我可以用传统的java迭代数据集来做到这一点,但我希望有另一种使用数据集功能的方法。 我的第一次尝试如下:

//dataset is already ordered by id, purchases desc
...
Dataset<Row> ds = dataset.repartition(new Column("id"));
ds.foreachPartition(new ForeachPartitionFunction<Row>() {
        @Override
        public void call(Iterator<Row> itrtr) throws Exception {
            int counter = 0;
            while (itrtr.hasNext()) {
                Row row = itrtr.next();
                if(counter < 2)
                //save it into another Dataset
                counter ++;
            }
        }
    });

但后来我迷失了如何将其保存到另一个数据集中。我的目标是,最后将结果保存到MySQL表中。

1 个答案:

答案 0 :(得分:2)

使用窗口功能和透视,您可以定义一个窗口:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, first, row_number}

val w = Window.partitionBy(col("id")).orderBy(col("purchases").desc)

添加row_number并过滤前两行:

val dataset = Seq(
  (1, "01", 20), (1, "02", 31), (2, "03", 5), (1, "03", 3)
).toDF("id", "shop", "purchases")

val topTwo = dataset.withColumn("top", row_number.over(w)).where(col("top") <= 2)

和pivot:

topTwo.groupBy(col("id")).pivot("top", Seq(1, 2)).agg(first("shop"))

结果为:

+---+---+----+
| id|  1|   2|
+---+---+----+
|  1| 02|  01|
|  2| 03|null|
+---+---+----+

我将把语法转换为Java作为海报的练习(不包括import static的函数,其余函数应该接近相同)。