Here is the code:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
Logger.getLogger("org").setLevel(Level.ERROR)
Logger.getLogger("akka").setLevel(Level.ERROR)
val conf = new SparkConf().setAppName("wordCount").setMaster("local[4]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
case class PatientReadings(
  patientId: Int,
  heightCm: Int,
  weightKg: Int,
  age: Int,
  isSmoker: Boolean
)

val readings = List(
  PatientReadings(1, 175, 72, 43, false),
  PatientReadings(2, 182, 78, 28, true),
  PatientReadings(3, 164, 61, 41, false),
  PatientReadings(4, 161, 62, 43, true)
)
val df = sc.parallelize(readings).toDF()
df.show()
val heightM = df("heightCm") / 100
val bmi = df("weightKg") / (heightM * heightM)
bmi
The console shows the following:
scala> bmi: org.apache.spark.sql.Column = (weightKg / ((heightCm / 100) * (heightCm / 100)))
Clearly, the division is performed twice. How can I avoid this?
Answer 0 (score: 2)
You can use a UDF:
import org.apache.spark.sql.functions.udf

val bmiFunc = udf((heightCm: Double, weightKg: Double) => {
  val heightM = heightCm / 100
  weightKg / (heightM * heightM)
})
val bmi = bmiFunc(df("heightCm"), df("weightKg"))
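A minimal usage sketch (the column name "bmi" is my choice, not part of the answer):

// "bmi" is an arbitrary column name; df is the DataFrame from the question
df.withColumn("bmi", bmi).show()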
Or, if you need heightM and bmi separately:
val heightM = df("heightCm") / 100
val bmiFunc = udf((heightM: Double, weightKg: Double) => {
  weightKg / (heightM * heightM)
})
val bmi = bmiFunc(heightM, df("weightKg"))
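To materialize both values on the DataFrame, a sketch under the same assumptions (the column names "heightM" and "bmi" are arbitrary):

// expressions are lazy; both are only evaluated when the plan runs
df.withColumn("heightM", heightM).withColumn("bmi", bmi).show()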
Answer 1 (score: 1)
These are auto-generated column names and should not affect the actual execution, but if you are worried, you can always replace the multiplication with the pow function:
import org.apache.spark.sql.functions.pow
$"weightKg" / pow($"heightCm" / 100, 2)