识别窗口(Scala)上一列的重复值

时间:2019-03-21 16:14:25

标签: scala apache-spark-sql user-defined-functions window-functions distinct-values

我有一个包含两列的数据框:“ ID”和“金额”,每一行代表特定ID和交易金额的交易。我的示例使用以下DF:

val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),
  (2, 50),(1, 30),(2, 120))).toDF("ID","Amount")

我想创建一个新列,以标识所述金额是否为重复值,即是否发生在其他任何使用相同ID的交易中。

我找到了一种更通用的方法,即使用以下功能在整个“金额”列中不考虑ID:

def recurring_amounts(df: DataFrame, col: String) : DataFrame = {
  var df_to_arr = df.select(col).rdd.map(r => r(0).asInstanceOf[Double]).collect()
  var arr_to_map = df_to_arr.groupBy(identity).mapValues(_.size)
  var map_to_df = arr_to_map.toSeq.toDF(col, "Count")
  var df_reformat = map_to_df.withColumn("Amount", $"Amount".cast(DoubleType))
  var df_out = df.join(df_reformat, Seq("Amount"))
  return df_new
}

val df_output = recurring_amounts(df, "Amount")

这将返回:

+---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120  |  3  |
| 1 | 120  |  3  |
| 2 |  40  |  1  |
| 2 |  50  |  1  | 
| 1 |  30  |  1  |
| 2 | 120  |  3  |
+---+------+-----+

然后我可以用它创建所需的二进制变量以指示该金额是否重复发生(如果> 1,则为是,否则为)。

但是,在此示例中,我的问题由值120表示,该值在ID 1而不是ID 2上重复出现。因此,我的期望输出是:

 +---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120  |  2  |
| 1 | 120  |  2  |
| 2 |  40  |  1  |
| 2 |  50  |  1  | 
| 1 |  30  |  1  |
| 2 | 120  |  1  |
+---+------+-----+

我一直在尝试一种方法来使用 .over(Window.partitionBy("ID"),但不确定如何执行。任何提示将不胜感激。

1 个答案:

答案 0 :(得分:0)

如果您精通sql,则可以为Dataframe编写sql查询。您需要做的第一件事是将Dataframe作为表格注册到Spark的内存中。之后,您可以在表顶部编写sql。请注意,spark是spark会话变量。

val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),(2, 50),(1, 30),(2, 120))).toDF("ID","Amount")
df.registerTempTable("transactions")
spark.sql("select *,count(*) over(partition by ID,Amount) as Count from transactions").show()

enter image description here

如有任何疑问,请告诉我。