Avoid recomputing a new column in Spark

Date: 2016-07-01 15:41:50

Tags: scala apache-spark apache-spark-sql

Here is the code:

    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.{SparkConf, SparkContext}

    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)

    val conf = new SparkConf().setAppName("wordCount").setMaster("local[4]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    case class PatientReadings(
      patientId: Int,
      heightCm: Int,
      weightKg: Int,
      age: Int,
      isSmoker: Boolean
    )

    val readings = List(
      PatientReadings(1, 175, 72, 43, false),
      PatientReadings(2, 182, 78, 28, true),
      PatientReadings(3, 164, 61, 41, false),
      PatientReadings(4, 161, 62, 43, true)
    )
    val df = sc.parallelize(readings).toDF()
    df.show()

    // Derived column expressions: height in metres, then BMI
    val heightM = df("heightCm") / 100
    val bmi = df("weightKg") / (heightM * heightM)

In the console, bmi is displayed as follows:

    scala> bmi: org.apache.spark.sql.Column = (weightKg / ((heightCm / 100) * (heightCm / 100)))

Clearly, the division is performed twice. How can I avoid this?

2 answers:

Answer 0 (score: 2)

You can use a UDF:

    import org.apache.spark.sql.functions.udf

    // The heightCm / 100 division happens only once, inside the UDF body
    val bmiFunc = udf((heightCm: Double, weightKg: Double) => {
      val heightM = heightCm / 100
      weightKg / (heightM * heightM)
    })
    val bmi = bmiFunc(df("heightCm"), df("weightKg"))

Or, if you need heightM and bmi separately:

    val heightM = df("heightCm") / 100
    // heightM is built once as a column expression and passed into the UDF
    val bmiFunc = udf((heightM: Double, weightKg: Double) => {
      weightKg / (heightM * heightM)
    })
    val bmi = bmiFunc(heightM, df("weightKg"))
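
For reference, a minimal usage sketch (the column name "bmi" is my own choice, not part of the original answer):

    // Attach the UDF result as a named column and inspect it
    df.withColumn("bmi", bmi).show()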

Answer 1 (score: 1)

These are just auto-generated column names and should not affect the actual execution, but if you are concerned, you can always replace the multiplication with the pow function:

    import org.apache.spark.sql.functions.pow

    $"weightKg" / pow($"heightCm" / 100, 2)