下面是我的Spark函数,它很简单
def doubleToRound(df:DataFrame,roundColsList:Array[String]): DataFrame ={
var y:DataFrame = df
for(colDF <- y.columns){
if(roundColsList.contains(colDF)){
y = y.withColumn(colDF,functions.round(y.col(colDF),2))
}
}
通过按给定DF的多列值将十进制值舍入到2位,这可以按预期工作。但是我遍历DataFrame y直到列Array [Sting] .length()。 还有其他更好的方法吗?
谢谢大家
答案 0 :(得分:3)
您可以简单地将select
与map
一起使用,如以下示例所示:
import org.apache.spark.sql.functions._
import spark.implicits._
val df = Seq(
("a", 1.22, 2.333, 3.4444),
("b", 4.55, 5.666, 6.7777)
).toDF("id", "v1", "v2", "v3")
val roundCols = df.columns.filter(_.startsWith("v")) // Or filter with other conditions
val otherCols = df.columns diff roundCols
df.select(otherCols.map(col) ++ roundCols.map(c => round(col(c), 2).as(c)): _*).show
// +---+----+----+----+
// | id| v1| v2| v3|
// +---+----+----+----+
// | a|1.22|2.33|3.44|
// | b|4.55|5.67|6.78|
// +---+----+----+----+
为其提供一种方法:
import org.apache.spark.sql.DataFrame
def doubleToRound(df: DataFrame, roundCols: Array[String]): DataFrame = {
val otherCols = df.columns diff roundCols
df.select(otherCols.map(col) ++ roundCols.map(c => round(col(c), 2).as(c)): _*)
}
或者,按如下方式使用foldLeft
和withColumn
:
def doubleToRound(df: DataFrame, roundCols: Array[String]): DataFrame =
roundCols.foldLeft(df)((acc, c) => acc.withColumn(c, round(col(c), 2)))