过滤Spark数据帧中的多个列

时间:2018-02-16 06:27:05

标签: scala apache-spark spark-dataframe

假设我在Spark中有一个数据框,如下所示 -

val df = Seq(
(0,0,0,0.0),
(1,0,0,0.1),
(0,1,0,0.11),
(0,0,1,0.12),
(1,1,0,0.24),
(1,0,1,0.27),
(0,1,1,0.30),
(1,1,1,0.40)
).toDF("A","B","C","rate")

以下是它的样子 -

scala> df.show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  0|  0|  0| 0.0|
|  1|  0|  0| 0.1|
|  0|  1|  0|0.11|
|  0|  0|  1|0.12|
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
|  1|  1|  1| 0.4|
+---+---+---+----+

A,B和C是这种情况下的广告渠道。 0和1分别表示通道的缺失和存在。 2 ^ 3显示了数据框中的8种组合。

我想过滤此数据框中的记录,一次显示2个频道(AB,AC,BC)。这就是我想要输出的方式 -

+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
+---+---+---+----+

我可以通过执行 -

编写3个语句来获取输出
scala> df.filter($"A" === 1 && $"B" === 1 && $"C" === 0).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
+---+---+---+----+


scala> df.filter($"A" === 1 && $"B" === 0  && $"C" === 1).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  0|  1|0.27|
+---+---+---+----+


scala> df.filter($"A" === 0 && $"B" === 1 && $"C" === 1).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  0|  1|  1| 0.3|
+---+---+---+----+

但是,我想使用一个完成我的工作的语句或一个帮助我获得输出的函数来实现这一点。 我在考虑使用case语句来匹配值。但一般来说,我的数据框可能包含3个以上的通道 -

scala> df.show()
+---+---+---+---+----+
|  A|  B|  C|  D|rate|
+---+---+---+---+----+
|  0|  0|  0|  0| 0.0|
|  0|  0|  0|  1| 0.1|
|  0|  0|  1|  0| 0.1|
|  0|  0|  1|  1|0.59|
|  0|  1|  0|  0| 0.1|
|  0|  1|  0|  1|0.89|
|  0|  1|  1|  0|0.39|
|  0|  1|  1|  1| 0.4|
|  1|  0|  0|  0| 0.0|
|  1|  0|  0|  1|0.99|
|  1|  0|  1|  0|0.49|
|  1|  0|  1|  1| 0.1|
|  1|  1|  0|  0|0.79|
|  1|  1|  0|  1| 0.1|
|  1|  1|  1|  0| 0.1|
|  1|  1|  1|  1| 0.1|
+---+---+---+---+----+

在这种情况下,我希望输出为 -

scala> df.show()
+---+---+---+---+----+
|  A|  B|  C|  D|rate|
+---+---+---+---+----+
|  0|  0|  1|  1|0.59|
|  0|  1|  0|  1|0.89|
|  0|  1|  1|  0|0.39|
|  1|  0|  0|  1|0.99|
|  1|  0|  1|  0|0.49|
|  1|  1|  0|  0|0.79|
+---+---+---+---+----+

表示配对存在的通道的速率=> (AB,AC,AD,BC,BD,CD)。

请帮助。

1 个答案:

答案 0 :(得分:3)

一种方法是对列进行求和,然后仅在和的结果为2时进行过滤。

import org.apache.spark.sql.functions._

df.withColumn("res", $"A" + $"B" + $"C").filter($"res" === lit(2)).drop("res").show

输出结果为:

+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
+---+---+---+----+