我在Spark中有以下数据框(只有一行):
df.show
+---+---+---+---+---+---+
| A| B| C| D| E| F|
+---+---+---+---+---+---+
| 1|4.4| 2| 3| 7|2.6|
+---+---+---+---+---+---+
我想获取其值大于2.8的列(仅作为示例)。结果应该是:
List(B, D , E)
这是我自己的解决方案:
val cols = df.columns
val threshold = 2.8
val values = df.rdd.collect.toList
val res = values
.flatMap(x => x.toSeq)
.map(x => x.toString.toDouble)
.zip(cols)
.filter(x => x._1 > threshold)
.map(x => x._2)
答案 0 :(得分:2)
您可以使用explode
和array
函数:
df.select(
explode(
array(
df.columns.map(c => struct(lit(c).alias("key"), col(c).alias("val"))): _*
)
).as("kv")
)
.where($"kv.val" > 2.8)
.select($"kv.key")
.show()
+---+
|key|
+---+
| B|
| D|
| E|
+---+
然后您可以收集此结果。但我认为首先收集数据帧没有任何问题,因为t只有1行:
df.columns.zip(df.first().toSeq.map(_.asInstanceOf[Double]))
.collect{case (c,v) if v>2.8 => c} // Array(B,D,E)
答案 1 :(得分:2)
一个简单的udf
函数应为您提供正确的结果
val columns = df.columns
def getColumns = udf((cols: Seq[Double]) => cols.zip(columns).filter(_._1 > 2.8).map(_._2))
df.withColumn("columns > 2.8", getColumns(array(columns.map(col(_)): _*))).show(false)
因此,即使您有如下所示的多行
+---+---+---+---+---+---+
|A |B |C |D |E |F |
+---+---+---+---+---+---+
|1 |4.4|2 |3 |7 |2.6|
|4 |2.7|2 |3 |1 |2.9|
+---+---+---+---+---+---+
您将获得每行结果
+---+---+---+---+---+---+-------------+
|A |B |C |D |E |F |columns > 2.8|
+---+---+---+---+---+---+-------------+
|1 |4.4|2 |3 |7 |2.6|[B, D, E] |
|4 |2.7|2 |3 |1 |2.9|[A, D, F] |
+---+---+---+---+---+---+-------------+
我希望答案会有所帮助
答案 2 :(得分:1)
val c = df.columns.foldLeft(df){(a,b) => a.withColumn(b, when(col(b) > 2.8, b))}
c.collect
您可以从数组中删除空值