火花在满足条件的列中获得最小值

时间:2018-11-19 19:48:02

标签: scala apache-spark dataframe

我的Spark中有一个DataFrame,看起来像这样:

id |  flag
----------
 0 |  true
 1 |  true
 2 | false
 3 |  true
 4 |  true
 5 |  true
 6 | false
 7 | false
 8 |  true
 9 | false

如果它具有flag == false或下一个假值的rowNumber,我想获得另一个具有当前rowNumber的列,因此输出将如下所示:

id |  flag | nextOrCurrentFalse
-------------------------------
 0 |  true |                  2
 1 |  true |                  2
 2 | false |                  2
 3 |  true |                  6
 4 |  true |                  6
 5 |  true |                  6
 6 | false |                  6
 7 | false |                  7
 8 |  true |                  9
 9 | false |                  9

我想以向量化的方式进行此操作(而不是逐行迭代)。所以我实际上希望逻辑是:

  • 对于每一行,获取的最小ID大于或等于当前行号,该行号具有== false标记

3 个答案:

答案 0 :(得分:2)

如果flag相当稀疏,您可以这样做:

val ids = df.where("flag = false"). 
             select($"id".as("id1"))  

val withNextFalse = df.join(ids, df("id") <= ids("id1")).
                      groupBy("id", "flag").
                      agg("id1" -> "min")

第一步,我们创建一个id为id的数据帧,其中标志为false。然后,我们在所需条件下将该数据帧连接到原始数据(原始ID应当小于或等于flag为false的行的ID)。

要获得 first 这种情况,请按ID分组并使用agg查找id1的最小值(这是带有flag = false的行的id

运行示例数据(并按ID排序)可提供所需的输出:

+---+-----+--------+
| id| flag|min(id1)|
+---+-----+--------+
|  0| true|       2|
|  1| true|       2|
|  2|false|       2|
|  3| true|       6|
|  4| true|       6|
|  5| true|       6|
|  6|false|       6|
|  7|false|       7|
|  8| true|       9|
|  9|false|       9|
+---+-----+--------+

如果DataFrame很大并且有很多行的标记为False,则此方法可能会遇到性能问题。如果是这种情况,使用迭代解决方案可能会更好。

答案 1 :(得分:2)

考虑过扩展等问题-但不清楚Catalyst是否足够好-我提出了一种解决方案,该解决方案基于可以从分区中受益并且要做的工作少得多的答案之一-只需考虑数据即可。它涉及预计算和处理,这一点可以使某些按摩击败暴力手段。您对JOIN的观点已不再是问题,因为这是目前有限的JOIN,并且没有大量数据生成。

您对数据框方法的评论有些不高兴,因为这里超出的只是数据框。我认为您的意思是您想循环通过一个数据框,并有一个带有出口的子循环。我找不到这样的例子,实际上我不确定它是否适合SPARK范例。获得相同的结果,但处理更少:

import org.apache.spark.sql.functions._
import spark.implicits._
import org.apache.spark.sql.expressions.Window

val df = Seq((0, true), (1, true), (2,false), (3, true), (4,true), (5,true), (6,false), (7,false), (8,true), (9,false)).toDF("id","flag")
@transient val  w1 = org.apache.spark.sql.expressions.Window.orderBy("id1")  

val ids = df.where("flag = false") 
            .select($"id".as("id1"))  

val ids2 = ids.select($"*", lag("id1",1,-1).over(w1).alias("prev_id"))
val ids3 = ids2.withColumn("prev_id1", col("prev_id")+1).drop("prev_id")

// Less and better performance at scale, this is better theoretically for Catalyst to bound partitions? Less work to do in any event.
// Some understanding of data required! And no grouping and min.
val withNextFalse = df.join(ids3, df("id") >= ids3("prev_id1") && df("id") <= ids3("id1"))
                     .select($"id", $"flag", $"id1".alias("nextOrCurrentFalse"))
                     .orderBy(asc("id"),asc("id"))

withNextFalse.show(false)

还返回:

+---+-----+------------------+
|id |flag |nextOrCurrentFalse|
+---+-----+------------------+
|0  |true |2                 |
|1  |true |2                 |
|2  |false|2                 |
|3  |true |6                 |
|4  |true |6                 |
|5  |true |6                 |
|6  |false|6                 |
|7  |false|7                 |
|8  |true |9                 |
|9  |false|9                 |
+---+-----+------------------+

答案 2 :(得分:0)

请参见其他更好的答案,但出于SQL教学目的,可能将其保留在此处。

这可以满足您的要求,但是我很想知道其他人对此有何看法。我将检查Catalyst并查看其程序上的工作方式,但是我认为这可能会导致分区界限丢失,我也希望对此进行检查。

import org.apache.spark.sql.functions._
val df = Seq((0, true), (1, true), (2,false), (3, true), (4,true), (5,true), (6,false), (7,false), (8,true), (9,false)).toDF("id","flag")
df.createOrReplaceTempView("tf") 

// Performance? Need to check at some stage how partitioning works in such a case.
spark.sql("CACHE TABLE tf") 
val res1 = spark.sql("""  
                       SELECT tf1.*, tf2.id as id2, tf2.flag as flag2
                         FROM tf tf1, tf tf2
                        WHERE tf2.id  >= tf1.id
                          AND tf2.flag = false 
                     """)    

//res1.show(false)
res1.createOrReplaceTempView("res1") 
spark.sql("CACHE TABLE res1") 

val res2 = spark.sql(""" SELECT X.id, X.flag, X.id2 
                           FROM (SELECT *, RANK() OVER (PARTITION BY id ORDER BY id2 ASC) as rank_val 
                                   FROM res1) X
                          WHERE X.rank_val = 1
                       ORDER BY id
                    """) 

res2.show(false)