比较foldleft中的列值

时间:2018-01-10 14:39:34

标签: scala apache-spark apache-spark-sql

我正在尝试将DataFrame中的列汇总到一个新列中,该列将添加到数据帧本身。

这是DataFrame

val input = sc.parallelize(Seq(
  ("r1", 1, 1),
  ("r2", 6, 4),
  ("r3", 4, 1),
  ("r4", 1, 2)
)).toDF("ID", "a", "b")

我只想在“a”和“b”中添加一个列为“1”的列。

这是我提出的Scala代码,遗憾的是它为任何行返回0并且无法使其正常工作。任何帮助表示赞赏!

import org.apache.spark.sql.functions._
import sqlContext.implicits._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{lit, col}

val columns = Seq("a", "b").map(col _)

def countOnes(cols: Column*) = cols.foldLeft(lit(0)){
  (cnt, current) => 
    if (current == 1) 
      cnt + 1
    else
      cnt
}

val output = input.withColumn("ones", countOnes(columns: _*))
output.show

预期结果是:

+---+---+---+----+
| ID|  a|  b|ones|
+---+---+---+----+
| r1|  1|  1|   2|
| r2|  6|  4|   0|
| r3|  4|  1|   1|
| r4|  1|  2|   1|
+---+---+---+----+

1 个答案:

答案 0 :(得分:3)

您可以使用reduce构建用于计算每行数量的列表达式,然后使用withColumn函数创建新列:

val ones = Seq("a", "b").map(x => when(col(x) === 1, 1).otherwise(0)).reduce(_ + _)

input.withColumn("ones", ones).show
+---+---+---+----+
| ID|  a|  b|ones|
+---+---+---+----+
| r1|  1|  1|   2|
| r2|  6|  4|   0|
| r3|  4|  1|   1|
| r4|  1|  2|   1|
+---+---+---+----+

或者,如果使用foldLeft,则需要when.otherwise代替if/else进行列操作:

def countOnes(cols: Column*) = cols.foldLeft(lit(0)){
    (cnt, current) => when(current === 1, cnt + 1).otherwise(cnt)
}

val output = input.withColumn("ones", countOnes(columns: _*))

output.show
+---+---+---+----+
| ID|  a|  b|ones|
+---+---+---+----+
| r1|  1|  1|   2|
| r2|  6|  4|   0|
| r3|  4|  1|   1|
| r4|  1|  2|   1|
+---+---+---+----+