如何使用Scala在Spark中滑动窗口等级?

时间:2019-05-14 02:51:13

标签: scala apache-spark window apache-spark-mllib ranking-functions

我有一个数据集:

+-----+-------------------+---------------------+------------------+
|query|similar_queries    |model_score          |count             |
+-----+-------------------+---------------------+------------------+
|shirt|funny shirt        |0.0034038130658784866|189.0             |
|shirt|shirt womens       |0.0019435265241921438|136.0             |
|shirt|watch              |0.001097496453284101 |212.0             |
|shirt|necklace           |6.694577024597908E-4 |151.0             |
|shirt|white shirt        |0.0037413097560623485|217.0             |
|shirt|shoes              |0.0022062579255572733|575.0             |
|shirt|crop top           |9.065831060804897E-4 |173.0             |
|shirt|polo shirts for men|0.007706416273211698 |349.0             |
|shirt|shorts             |0.002669621942466027 |200.0             |
|shirt|black shirt        |0.03264296242546658  |114.0             |
+-----+-------------------+---------------------+------------------+

我首先根据“计数”对数据集进行排名。

lazy val countWindowByFreq = Window.partitionBy(col(QUERY)).orderBy(col(COUNT).desc)
val ranked_data = data.withColumn("count_rank", row_number over countWindowByFreq)

+-----+-------------------+---------------------+------------------+----------+
|query|similar_queries    |model_score          |count             |count_rank|
+-----+-------------------+---------------------+------------------+----------+
|shirt|shoes              |0.0022062579255572733|575.0             |1         |
|shirt|polo shirts for men|0.007706416273211698 |349.0             |2         |
|shirt|white shirt        |0.0037413097560623485|217.0             |3         |
|shirt|watch              |0.001097496453284101 |212.0             |4         |
|shirt|shorts             |0.002669621942466027 |200.0             |5         |
|shirt|funny shirt        |0.0034038130658784866|189.0             |6         |
|shirt|crop top           |9.065831060804897E-4 |173.0             |7         |
|shirt|necklace           |6.694577024597908E-4 |151.0             |8         |
|shirt|shirt womens       |0.0019435265241921438|136.0             |9         |
|shirt|black shirt        |0.03264296242546658  |114.0             |10        |
+-----+-------------------+---------------------+------------------+----------+

我现在正在尝试使用row_number(4行)上的滚动窗口对内容进行排名,并基于model_score在窗口内进行排名。例如:

在第一个窗口中,行编号1到4,新排名(新列)将为

1. polo shirts for men
2. white shirt
3. shoes
4. watch

在第一个窗口中,行号5到8,新排名(新列)将为

5. funny shirt
6. shorts
7. shirt womens 
8. crop top

在第一个窗口中,第9行要休息,新排名(新列)将为

9. black shirt 
10. shirt womens

有人可以告诉我,使用Spark和Scala是否可以实现?我可以使用任何预定义的功能吗?

我尝试过:

lazy val MODEL_RANK = Window.partitionBy(col(QUERY))     .orderBy(col(MODEL_SCORE).desc).rowsBetween(0,3)

但这给了我:

sql.AnalysisException: Window Frame ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING must match the required frame ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW;

此外,尝试使用.rowsBetween(-3,0),但这也给我错误:

org.apache.spark.sql.AnalysisException: Window Frame ROWS BETWEEN 3 PRECEDING AND CURRENT ROW must match the required frame ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW;

1 个答案:

答案 0 :(得分:2)

由于已经计算出count_rank,因此下一步是找到一种将行分组为一组的方法。可以完成以下操作:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val ranked_data_grouped = ranked_data
  .withColumn("bucket", (($"count_rank" -1)/4).cast(IntegerType))

ranked_data_grouped如下所示:

+-----+-------------------+---------------------+------------------+----------+-------+
|query|similar_queries    |model_score          |count             |count_rank|bucket |
+-----+-------------------+---------------------+------------------+----------+-------+
|shirt|shoes              |0.0022062579255572733|575.0             |1         |0      |
|shirt|polo shirts for men|0.007706416273211698 |349.0             |2         |0      |      
|shirt|white shirt        |0.0037413097560623485|217.0             |3         |0      |
|shirt|watch              |0.001097496453284101 |212.0             |4         |0      |
|shirt|shorts             |0.002669621942466027 |200.0             |5         |1      |
|shirt|funny shirt        |0.0034038130658784866|189.0             |6         |1      |
|shirt|crop top           |9.065831060804897E-4 |173.0             |7         |1      |
|shirt|necklace           |6.694577024597908E-4 |151.0             |8         |1      |
|shirt|shirt womens       |0.0019435265241921438|136.0             |9         |2      |
|shirt|black shirt        |0.03264296242546658  |114.0             |10        |2      |
+-----+-------------------+---------------------+------------------+----------+-------+

现在,您要做的就是按bucket进行分区并按model_score进行排序:

val output = ranked_data_grouped
  .withColumn("finalRank", row_number().over(Window.partitionBy($"bucket").orderBy($"model_score".desc)))