Spark RDD - 用其他列的平均值替换缺失的列

时间:2017-09-27 18:48:26

标签: scala apache-spark rdd


输入 RDD(部分字段缺失):

RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, , 2), (003, 4, , 9, 2), (003, 2, 3, 0, 1) )



期望输出:缺失值替换为同一键(第一列)下该列的平均值,缺失值按 0 计入分子,分母为该键的总行数:

RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, (1+3)/3 , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, (5+9+0)/4 , 2), (003, 4, (4+4+3)/4 , 9, 2), (003, 2, 3, 0, 1) )


2 个答案:

答案 0 :(得分:2)

// Create the exact input data provided as a Spark DataFrame/DataSet.
// Missing values are represented as nulls; StructField is nullable by default,
// so the Integer columns b..e accept them.
// Assumes a SparkSession named `spark` is in scope (as in spark-shell).
val df = {
  import org.apache.spark.sql._
  import org.apache.spark.sql.types._
  import scala.collection.JavaConverters._

  val simpleSchema = StructType(
    StructField("a", StringType) ::
    StructField("b", IntegerType) ::
    StructField("c", IntegerType) ::
    StructField("d", IntegerType) ::
    StructField("e", IntegerType) :: Nil)

  val data = List(
    Row("001", 1, 0, 3, 4),
    Row("001", 3, 4, 1, 7),
    Row("001", null, 0, 6, 4),
    Row("003", 1, 4, 5, 7),
    Row("003", 5, 4, null, 2),
    Row("003", 4, null, 9, 2),
    Row("003", 2, 3, 0, 1)
  )

  // This createDataFrame overload takes a java.util.List[Row], hence asJava.
  spark.createDataFrame(data.asJava, simpleSchema)
}

// Per-key averages of columns b..e.
// fill replaces nulls with zero, which we need for the desired averaging: a
// missing value adds 0 to the sum while its row still counts in the divisor,
// e.g. column b for key "001" becomes (1 + 3 + 0) / 3.
// The aggregated columns are named "avg(b)", "avg(c)", "avg(d)", "avg(e)".
import org.apache.spark.sql.functions.col

val avgs = df.na.fill(0).groupBy(col("a")).avg("b", "c", "d", "e").as("avgs")

// Join every input row with the averages for its key, then keep the original
// value where present and fall back to the per-key average where it is null.
// Both sides are aliased ("df" / "avgs") so the shared key column `a` can be
// referenced unambiguously in the join condition and the projection.
import org.apache.spark.sql.functions.{coalesce, col}

val output = df.as("df")
  .join(avgs, col("df.a") === col("avgs.a"))
  .select(
    col("df.a"),
    coalesce(col("df.b"), col("avg(b)")),
    coalesce(col("df.c"), col("avg(c)")),
    coalesce(col("df.d"), col("avg(d)")),
    coalesce(col("df.e"), col("avg(e)"))
  )

|  a|   b|   c|   d|  e|
|001|   1|   0|   3|  4|
|001|   3|   4|   1|  7|
|001|null|   0|   6|  4|
|003|   1|   4|   5|  7|
|003|   5|   4|null|  2|
|003|   4|null|   9|  2|
|003|   2|   3|   0|  1|

|  a|            avg(b)|            avg(c)|            avg(d)|avg(e)|
|003|               3.0|              2.75|               3.5|   3.0|
|001|1.3333333333333333|1.3333333333333333|3.3333333333333335|   5.0|

|  a|coalesce(df.b, avg(b))|coalesce(df.c, avg(c))|coalesce(df.d, avg(d))|coalesce(df.e, avg(e))|
|001|                   1.0|                   0.0|                   3.0|                   4.0|
|001|                   3.0|                   4.0|                   1.0|                   7.0|
|001|    1.3333333333333333|                   0.0|                   6.0|                   4.0|
|003|                   1.0|                   4.0|                   5.0|                   7.0|
|003|                   5.0|                   4.0|                   3.5|                   2.0|
|003|                   4.0|                  2.75|                   9.0|                   2.0|
|003|                   2.0|                   3.0|                   0.0|                   1.0|

答案 1 :(得分:1)


// Sample input as a DataFrame; None marks a missing value and becomes null.
// toDF on a local Seq needs the SparkSession implicits (already in scope in
// spark-shell; the import makes the snippet work elsewhere too).
import spark.implicits._

val df = Seq(
  ("001", Some(1), Some(0), Some(3), Some(4)),
  ("001", Some(3), Some(4), Some(1), Some(7)),
  ("001", None, Some(0), Some(6), Some(4)),
  ("003", Some(1), Some(4), Some(5), Some(7)),
  ("003", Some(5), Some(4), None, Some(2)),
  ("003", Some(4), None, Some(9), Some(2)),
  ("003", Some(2), Some(3), Some(0), Some(1))
).toDF("a", "b", "c", "d", "e")

// Expose the DataFrame under the table name the query below reads from.
df.createOrReplaceTempView("avg_temp")

// For each column, keep the value if present; otherwise use the per-key
// sum / row count computed with window functions. sum() skips nulls while
// count(*) counts every row, so a missing value counts as 0 in the average.
spark.sql("""  select a, coalesce(b,sum(b) over(partition by a)/count(*) over(partition by a)) b1, coalesce( c, sum(c) over(partition by a)/count(*) over(partition by a)) c1,
            coalesce( d, sum(d) over(partition by a)/count(*) over(partition by a)) d1, coalesce( e, sum(e) over(partition by a)/count(*) over(partition by a)) e1 from avg_temp
          """).show(false)


|a  |b   |c   |d   |e  |
|001|1   |0   |3   |4  |
|001|3   |4   |1   |7  |
|001|null|0   |6   |4  |
|003|1   |4   |5   |7  |
|003|5   |4   |null|2  |
|003|4   |null|9   |2  |
|003|2   |3   |0   |1  |
|a  |b1                |c1  |d1 |e1 |
|003|1.0               |4.0 |5.0|7.0|
|003|5.0               |4.0 |3.5|2.0|
|003|4.0               |2.75|9.0|2.0|
|003|2.0               |3.0 |0.0|1.0|
|001|1.0               |0.0 |3.0|4.0|
|001|3.0               |4.0 |1.0|7.0|
|001|1.3333333333333333|0.0 |6.0|4.0|