如何基于Spark Dataframe中的行值获取列的子集?

时间:2018-08-23 13:55:14

标签: apache-spark apache-spark-sql

我在Spark中有以下数据框(只有一行):

df.show
+---+---+---+---+---+---+
|  A|  B|  C|  D|  E|  F|
+---+---+---+---+---+---+
|  1|4.4|  2|  3|  7|2.6|
+---+---+---+---+---+---+

我想获取其值大于2.8的列(仅作为示例)。结果应该是:

List(B, D , E)

这是我自己的解决方案:

val cols = df.columns
val threshold = 2.8
val values = df.rdd.collect.toList
val res = values
         .flatMap(x => x.toSeq)
         .map(x => x.toString.toDouble)
         .zip(cols)
         .filter(x => x._1 > threshold)
         .map(x => x._2)

3 个答案:

答案 0 :(得分:2)

您可以使用explodearray函数:

df.select(
    explode(
      array(
        df.columns.map(c => struct(lit(c).alias("key"), col(c).alias("val"))): _*
      )
    ).as("kv")
  )
  .where($"kv.val" > 2.8)
  .select($"kv.key")
  .show()

+---+
|key|
+---+
|  B|
|  D|
|  E|
+---+

然后您可以收集此结果。但我认为首先收集数据帧没有任何问题,因为t只有1行:

df.columns.zip(df.first().toSeq.map(_.asInstanceOf[Double]))
      .collect{case (c,v) if v>2.8 => c} // Array(B,D,E)

答案 1 :(得分:2)

一个简单的udf函数应为您提供正确的结果

val columns = df.columns

def getColumns = udf((cols: Seq[Double]) => cols.zip(columns).filter(_._1 > 2.8).map(_._2))

df.withColumn("columns > 2.8", getColumns(array(columns.map(col(_)): _*))).show(false)

因此,即使您有如下所示的多行

+---+---+---+---+---+---+
|A  |B  |C  |D  |E  |F  |
+---+---+---+---+---+---+
|1  |4.4|2  |3  |7  |2.6|
|4  |2.7|2  |3  |1  |2.9|
+---+---+---+---+---+---+

您将获得每行结果

+---+---+---+---+---+---+-------------+
|A  |B  |C  |D  |E  |F  |columns > 2.8|
+---+---+---+---+---+---+-------------+
|1  |4.4|2  |3  |7  |2.6|[B, D, E]    |
|4  |2.7|2  |3  |1  |2.9|[A, D, F]    |
+---+---+---+---+---+---+-------------+

我希望答案会有所帮助

答案 2 :(得分:1)

val c = df.columns.foldLeft(df){(a,b) =>  a.withColumn(b, when(col(b) > 2.8, b))}
c.collect

您可以从数组中删除空值