我正在尝试将DataFrame中的列汇总到一个新列中,该列将添加到数据帧本身。
这是DataFrame
val input = sc.parallelize(Seq(
("r1", 1, 1),
("r2", 6, 4),
("r3", 4, 1),
("r4", 1, 2)
)).toDF("ID", "a", "b")
我只想在“a”和“b”中添加一个列为“1”的列。
这是我提出的Scala代码,遗憾的是它为任何行返回0并且无法使其正常工作。任何帮助表示赞赏!
import org.apache.spark.sql.functions._
import sqlContext.implicits._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{lit, col}
val columns = Seq("a", "b").map(col _)
def countOnes(cols: Column*) = cols.foldLeft(lit(0)){
(cnt, current) =>
if (current == 1)
cnt + 1
else
cnt
}
val output = input.withColumn("ones", countOnes(columns: _*))
output.show
预期结果是:
+---+---+---+----+
| ID| a| b|ones|
+---+---+---+----+
| r1| 1| 1| 2|
| r2| 6| 4| 0|
| r3| 4| 1| 1|
| r4| 1| 2| 1|
+---+---+---+----+
答案 0 :(得分:3)
您可以使用reduce
构建用于计算每行数量的列表达式,然后使用withColumn
函数创建新列:
val ones = Seq("a", "b").map(x => when(col(x) === 1, 1).otherwise(0)).reduce(_ + _)
input.withColumn("ones", ones).show
+---+---+---+----+
| ID| a| b|ones|
+---+---+---+----+
| r1| 1| 1| 2|
| r2| 6| 4| 0|
| r3| 4| 1| 1|
| r4| 1| 2| 1|
+---+---+---+----+
或者,如果使用foldLeft
,则需要when.otherwise
代替if/else
进行列操作:
def countOnes(cols: Column*) = cols.foldLeft(lit(0)){
(cnt, current) => when(current === 1, cnt + 1).otherwise(cnt)
}
val output = input.withColumn("ones", countOnes(columns: _*))
output.show
+---+---+---+----+
| ID| a| b|ones|
+---+---+---+----+
| r1| 1| 1| 2|
| r2| 6| 4| 0|
| r3| 4| 1| 1|
| r4| 1| 2| 1|
+---+---+---+----+