根据Map替换Spark Dataframe中的值

时间:2017-12-12 22:48:29

标签: scala apache-spark

我有一个整数数据集,其中一些是真实数据,其中一些超过一定的阈值是错误代码。我还有一个列名称映射到其错误代码范围的开头。我想使用此映射有条件地替换值,例如,如果每列中的行的值高于错误范围的开头,则为None。

val errors = Map("Col_1" -> 100, "Col_2" -> 10)

val df = Seq(("john", 1, 100), ("jacob", 10, 100), ("heimer", 1000, 
1)).toDF("name", "Col_1", "Col_2")

df.take(3)
// name   | Col_1 | Col_2
// john   | 1     | 1
// jacob  | 10    | 10
// heimer | 1000  | 1

//create some function like this
def fixer = udf((column_value, column_name) => {
    val crit_val = errors(column_name)
    if(column_value >= crit_val) {
        None
    } else {
        column_value
    }
}

//apply it in some way
val fixed_df = df.columns.map(_ -> fixer(_))

//to get output like this:
fixed_df.take(3)
// name   | Col_1 | Col_2
// john   | 1     | 1
// jacob  | 10    | None
// heimer | None  | 1

1 个答案:

答案 0 :(得分:3)

使用UDF执行此操作并不太方便 - UDF需要特定列(或多个列)并返回单个列,此处您需要处理各种不同的列。此外,可以使用Spark的内置方法when执行检查阈值并使用某个常量替换值的操作,并且不需要UDF。

所以,这里有一种方法可以为每个具有一定阈值的列使用when,从而迭代地遍历相关列并生成所需的DataFrame(我们将替换"更糟糕的"坏& #34;值null):

import org.apache.spark.sql.functions._
import spark.implicits._

// fold the list of errors, replacing the original column
// with a "corrected" column with same name in each iteration
val newDf = errors.foldLeft(df) { case (tmpDF, (colName, threshold)) =>
  tmpDF.withColumn(colName, when($"$colName" > threshold, null).otherwise($"$colName"))
}

newDf.show()
// +------+-----+-----+
// |  name|Col_1|Col_2|
// +------+-----+-----+
// |  john|    1|    1|
// | jacob|   10| null|
// |heimer| null|    1|
// +------+-----+-----+