用Scala-Spark中的行的平均值填充Nan

时间:2017-04-05 10:43:57

标签: scala apache-spark

我有一个RDD有6列,其中最后5列可能包含NaN。我的目的是将NaNs替换为不是Nan的行的最后5个值的其余值的平均值。例如,有这个输入:

1, 2, 3, 4, 5, 6
2, 2, 2, NaN, 4, 0
3, NaN, NaN, NaN, 6, 0
4, NaN, NaN, 4, 4, 0 

输出应为:

1, 2, 3, 4, 5, 6
2, 2, 2, 2, 4, 0
3, 3, 3, 3, 6, 0
4, 3, 3, 4, 4, 0

我知道如何使用将RDD转换为DataFrame的列的平均值来填充这些NaN:

var aux1 = df.select(df.columns.map(c => mean(col(c))) :_*)
var aux2 = df.na.fill(/*get values of aux1*/)

我的问题是,你怎么能做这个操作,而不是用列平均值填充NaN,用它的一个子组的平均值填充它?

3 个答案:

答案 0 :(得分:2)

你可以通过定义一个函数来获得平均值,另一个函数来填充一行中的空值。

鉴于您提出的DF:

Schema::create('Users', function (Blueprint $table) {
            $table->increments('id');
            $table->string('username')->unique();
            $table->string('password');
            $table->boolean('isactive');
            $table->timestamps();
        });

我们需要一个函数来获得行的平均值:

val df = sc.parallelize(List((Some(1),Some(2),Some(3),Some(4),Some(5),Some(6)),(Some(2),Some(2),Some(2),None,Some(4),Some(0)),(Some(3),None,None,None,Some(6),Some(0)),(Some(4),None,None,Some(4),Some(4),Some(0)))).toDF("a","b","c","d","e","f")

另一个填充行中的空值:

import org.apache.spark.sql.Row
def rowMean(row: Row): Int = {
   val nonNulls = (0 until row.length).map(i => (!row.isNullAt(i), row.getAs[Int](i))).filter(_._1).map(_._2).toList
   nonNulls.sum / nonNulls.length
}

现在我们可以先计算每一行的意思:

def rowFillNulls(row: Row, fill: Int): Row = {
   Row((0 until row.length).map(i => if (row.isNullAt(i)) fill else row.getAs[Int](i)) : _*)
}

然后填写:

val rowWithMean = df.map(row => (row,rowMean(row)))

最后查看之前和之后...

val result = sqlContext.createDataFrame(rowWithMean.map{case (row,mean) => rowFillNulls(row,mean)}, df.schema)

这适用于任何带有Int列的宽度DF。您可以轻松地将其更新为其他数据类型,甚至是非数字类型(提示,检查df架构!)

答案 1 :(得分:1)

嗯,这是一个有趣的小问题 - 我会发布我的解决方案,但我肯定会观察并看看是否有人想出更好的方法:)

首先,我将介绍几个udf s:

val avg = udf((values: Seq[Integer]) => {
  val notNullValues = values.filter(_ != null).map(_.toInt)
  notNullValues.sum/notNullValues.length
})

val replaceNullWithAvg = udf((x: Integer, avg: Integer) => if(x == null) avg else x)

然后我将这样应用于DataFrame

dataframe
  .withColumn("avg", avg(array(df.columns.tail.map(s => df.col(s)):_*)))
  .select('col1, replaceNullWithAvg('col2, 'avg) as "col2", replaceNullWithAvg('col3, 'avg) as "col3", replaceNullWithAvg('col4, 'avg) as "col4", replaceNullWithAvg('col5, 'avg) as "col5", replaceNullWithAvg('col6, 'avg) as "col6")

这将为您提供所需的信息,但可能不是我曾经汇集过的最复杂的代码......

答案 2 :(得分:1)

一堆进口:

import org.apache.spark.sql.functions.{col, isnan, isnull, round, when}
import org.apache.spark.sql.Column

一些辅助功能:

def nullOrNan(c: Column) = isnan(c) || isnull(c)

def rowMean(cols: Column*): Column = {
  val sum = cols
    .map(c => when(nullOrNan(c), lit(0.0)).otherwise(c))
    .fold(lit(0.0))(_ + _)
  val count = cols
    .map(c => when(nullOrNan(c), lit(0.0)).otherwise(lit(1.0)))
    .fold(lit(0.0))(_ + _)
  sum / count
}

解决方案:

val mean = round(
  rowMean(df.columns.tail.map(col): _*)
).cast("int").alias("mean")

val exprs = df.columns.tail.map(
  c => when(nullOrNan(col(c)), mean).otherwise(col(c)).alias(c)
)

val filled = df.select(col(df.columns(0)) +: exprs: _*)