我在Scala中有一个Spark数据帧,如下所示 -
val df = Seq(
(0,0,0,0.0,0),
(1,0,0,0.1,1),
(0,1,0,0.11,1),
(0,0,1,0.12,1),
(1,1,0,0.24,2),
(1,0,1,0.27,2),
(0,1,1,0.3,2),
(1,1,1,0.4,3)
).toDF("A","B","C","rate","total")
以下是它的样子
scala> df.show
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 0| 0| 0| 0.0| 0|
| 1| 0| 0| 0.1| 1|
| 0| 1| 0|0.11| 1|
| 0| 0| 1|0.12| 1|
| 1| 1| 0|0.24| 2|
| 1| 0| 1|0.27| 2|
| 0| 1| 1| 0.3| 2|
| 1| 1| 1| 0.4| 3|
+---+---+---+----+-----+
在这种情况下,A,B和C是通道。 0和1分别表示通道的缺失和存在。 2 ^ 3显示数据帧中的8种组合,其中“total”列给出了这3个通道的行方和。
这些频道出现的个别概率可以通过 -
给出scala> val oneChannelCase = df.filter($"total" === 1).toDF()
scala> oneChannelCase.show()
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 1| 0| 0| 0.1| 1|
| 0| 1| 0|0.11| 1|
| 0| 0| 1|0.12| 1|
+---+---+---+----+-----+
但是,我只对这些通道的成对概率感兴趣,这些概率由 -
给出scala> val probs = df.filter($"total" === 2).toDF()
scala> probs.show()
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 1| 1| 0|0.24| 2|
| 1| 0| 1|0.27| 2|
| 0| 1| 1| 0.3| 2|
+---+---+---+----+-----+
我想做的是 - 在这些显示个体概率的“probs”数据框中添加3个新列。以下是我要找的输出 -
A B C rate prob_A prob_B prob_C
1 1 0 0.24 0.1 0.11 0
1 0 1 0.27 0.1 0 0.12
0 1 1 0.3 0 0.11 0.12
为了使事情更清楚,输出结果的第一行显示A = 1,B = 1,C = 0。因此,A = 0.1,B = 0.11和C = 0的各个概率分别附加到probs数据帧。类似地,对于第二行,A = 1,B = 0,C = 1表示A = 0.1,B = 0和C = 0.12的个体概率分别附加到probs数据帧。
这是我尝试过的 -
scala> val channels = df.columns.filter(v => !(v.contains("rate") | v.contains("total")))
#channels: Array[String] = Array(A, B, C)
scala> val pivotedProb = channels.map(v => f"case when $v = 1 then rate else 0 end as prob_${v}")
scala> val param = pivotedProb.mkString(",")
scala> val probs = spark.sql(f"select *, $param from df")
scala> probs.show()
+---+---+---+----+-----+------+------+------+
| A| B| C|rate|total|prob_A|prob_B|prob_C|
+---+---+---+----+-----+------+------+------+
| 0| 0| 0| 0.0| 0| 0.0| 0.0| 0.0|
| 1| 0| 0| 0.1| 1| 0.1| 0.0| 0.0|
| 0| 1| 0|0.11| 1| 0.0| 0.11| 0.0|
| 0| 0| 1|0.12| 1| 0.0| 0.0| 0.12|
| 1| 1| 0|0.24| 2| 0.24| 0.24| 0.0|
| 1| 0| 1|0.27| 2| 0.27| 0.0| 0.27|
| 0| 1| 1| 0.3| 2| 0.0| 0.3| 0.3|
| 1| 1| 1| 0.4| 3| 0.4| 0.4| 0.4|
+---+---+---+----+-----+------+------+------+
这给了我错误的输出。
请帮助。
答案 0 :(得分:2)
如果我正确理解您的要求,使用>
遍历频道列,您可以1)从单频道数据帧生成foldLeft
,并且2)向两个频道添加列 - 通道数据帧,其列值等于通道和对应的ratesMap
值的乘积:
ratesMap