I have a DataFrame:
joined.printSchema
root
|-- cc_num: long (nullable = true)
|-- lat: double (nullable = true)
|-- long: double (nullable = true)
|-- merch_lat: double (nullable = true)
|-- merch_long: double (nullable = true)
I have a UDF:
def getDistance(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double = {
  val r: Int = 6371 // Earth radius in km
  val latDistance: Double = Math.toRadians(lat2 - lat1)
  val lonDistance: Double = Math.toRadians(lon2 - lon1)
  val a: Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) +
    Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) *
    Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
  val c: Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
  val distance: Double = r * c
  distance
}
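For reference, getDistance implements the haversine formula for the great-circle distance on a sphere, writing φ for latitude (lat1, lat2) and λ for longitude (lon1, lon2), both converted to radians:

```latex
\begin{aligned}
a &= \sin^2\!\left(\frac{\varphi_2 - \varphi_1}{2}\right) + \cos\varphi_1 \cos\varphi_2 \, \sin^2\!\left(\frac{\lambda_2 - \lambda_1}{2}\right) \\
c &= 2\,\operatorname{atan2}\!\left(\sqrt{a},\ \sqrt{1 - a}\right) \\
d &= R\,c, \qquad R = 6371\ \text{km}
\end{aligned}
```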
I need to add a new column to the DataFrame with the following command:
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
I get the following error messages:
Name: Unknown Error
Message: <console>:35: error: type mismatch;
found : String("lat")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("long")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("merch_lat")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("merch_long")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
As the schema shows, all the fields involved are of type double, which matches the parameter types of the UDF. Why do I get a type mismatch error?
Can anyone shed some light on the problem and how to fix it?
Thank you very much.
Answer 0 (score: 2)
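Your getDistance method is not a UDF. It is a plain Scala method that takes four Double arguments, and you are passing it four Strings (the column names). The compiler checks these types when the line is compiled, long before Spark looks at the DataFrame schema, so the double columns never come into play.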
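To illustrate the point, here is a minimal sketch with no Spark involved: calling the method directly with Double values compiles and runs (0° to 1° of longitude on the equator comes out at roughly 111.19 km), while the String call from the question is rejected at compile time. The object name `DirectCall` is hypothetical; the method body is the one from the question.

```scala
object DirectCall {
  // Same body as getDistance in the question: a plain Scala method on Doubles.
  def getDistance(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double = {
    val r: Int = 6371 // Earth radius in km
    val latDistance = Math.toRadians(lat2 - lat1)
    val lonDistance = Math.toRadians(lon2 - lon1)
    val a = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) +
      Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) *
      Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    val c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
    r * c
  }

  def main(args: Array[String]): Unit = {
    // Compiles: every argument is a Double.
    println(getDistance(0.0, 0.0, 0.0, 1.0))
    // Does not compile ("type mismatch; found: String, required: Double"):
    // getDistance("lat", "long", "merch_lat", "merch_long")
  }
}
```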
To fix this, wrap the method in a UDF and pass it Columns instead of Strings:
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession
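// Lift the Scala method into a Spark UDF that operates on Columns
val distanceUdf: UserDefinedFunction = udf(getDistance _)
// Pass Columns ($"name"), not plain Strings; withColumn returns a new DataFrame
val withDistance = joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))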
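Below is a self-contained sketch of the fix end to end, assuming spark-sql 2.1 or later is on the classpath (for `functions.radians`) and a local SparkSession. It applies the UDF from the answer and, as an alternative, computes the same formula from Spark's built-in column functions (`radians`, `sin`, `cos`, `pow`, `sqrt`, `atan2`). The built-in version generally lets the Catalyst optimizer see the whole expression, whereas a UDF is opaque to it. The names `DistanceExample`, `distanceCol`, `sampleDistances` and the one-row sample DataFrame are hypothetical, chosen only for illustration.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions._

object DistanceExample {
  // The method from the question: a plain Scala method on Doubles.
  def getDistance(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double = {
    val r = 6371.0 // Earth radius in km
    val latDistance = Math.toRadians(lat2 - lat1)
    val lonDistance = Math.toRadians(lon2 - lon1)
    val a = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) +
      Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) *
      Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    r * 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
  }

  // Option 1 (the answer's fix): wrap the method in a UDF.
  val distanceUdf = udf(getDistance _)

  // Option 2 (alternative): the same haversine formula built from Column functions, no UDF.
  def distanceCol(lat1: Column, lon1: Column, lat2: Column, lon2: Column): Column = {
    val dLat = radians(lat2 - lat1)
    val dLon = radians(lon2 - lon1)
    val a = pow(sin(dLat / 2), 2.0) +
      cos(radians(lat1)) * cos(radians(lat2)) * pow(sin(dLon / 2), 2.0)
    atan2(sqrt(a), sqrt(lit(1.0) - a)) * (2 * 6371.0)
  }

  // Builds a one-row DataFrame shaped like the question's schema; returns (udf, built-in) results.
  def sampleDistances(): (Double, Double) = {
    val spark = SparkSession.builder().master("local[1]").appName("distance-example").getOrCreate()
    try {
      import spark.implicits._
      val joined = Seq((1L, 0.0, 0.0, 0.0, 1.0))
        .toDF("cc_num", "lat", "long", "merch_lat", "merch_long")
      val result = joined
        .withColumn("distance_udf", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))
        .withColumn("distance_builtin", distanceCol($"lat", $"long", $"merch_lat", $"merch_long"))
      val row = result.select("distance_udf", "distance_builtin").head()
      (row.getDouble(0), row.getDouble(1))
    } finally spark.stop()
  }

  def main(args: Array[String]): Unit = println(sampleDistances())
}
```

Both columns should agree up to floating-point rounding; which one to use is mostly a question of whether you want to keep the existing Scala method (UDF) or let Spark optimize the expression (built-ins).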