Why does the data type change when calling a UDF in Scala?

Asked: 2019-03-13 19:07:44

Tags: scala apache-spark apache-spark-sql

I have a DataFrame:

joined.printSchema
root
 |-- cc_num: long (nullable = true)
 |-- lat: double (nullable = true)
 |-- long: double (nullable = true)
 |-- merch_lat: double (nullable = true)
 |-- merch_long: double (nullable = true)

I have a UDF:

def getDistance (lat1:Double, lon1:Double, lat2:Double, lon2:Double) = {
    val r : Int = 6371 // Earth radius in km
    val latDistance : Double = Math.toRadians(lat2 - lat1)
    val lonDistance : Double = Math.toRadians(lon2 - lon1)
    val a : Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) + Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    val c : Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
    val distance : Double = r * c
    distance
  }
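For context, the method itself works fine when called directly with Double values; the problem only appears when it is applied to DataFrame columns. A minimal standalone check (the coordinates below are illustrative, not from the question's data):

```scala
object DistanceCheck extends App {
  // Same haversine method as in the question, with explicit types.
  def getDistance(lat1: Double, lon1: Double, lat2: Double, lon2: Double): Double = {
    val r: Int = 6371 // Earth radius in km
    val latDistance: Double = Math.toRadians(lat2 - lat1)
    val lonDistance: Double = Math.toRadians(lon2 - lon1)
    val a: Double =
      Math.sin(latDistance / 2) * Math.sin(latDistance / 2) +
        Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) *
          Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    val c: Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
    r * c
  }

  // One degree of longitude at the equator is roughly 111 km.
  val d = getDistance(0.0, 0.0, 0.0, 1.0)
  println(f"$d%.2f km")
}
```

Passing Doubles compiles; passing column-name Strings, as in the question, does not, which is exactly what the compiler error reports.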

I need to generate a new column for the DataFrame with the following:

joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))

I get the error message below:

Name: Unknown Error
Message: <console>:35: error: type mismatch;
 found   : String("lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                          ^
<console>:35: error: type mismatch;
 found   : String("long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                 ^
<console>:35: error: type mismatch;
 found   : String("merch_lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                         ^
<console>:35: error: type mismatch;
 found   : String("merch_long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                                      ^

As the schema shows, all of the fields involved are of type double, which matches the UDF's parameter types, so why am I seeing a data type mismatch error?

Can anyone shed some light on the problem and the fix?

Thank you very much.

1 Answer:

Answer 0 (score: 2)

Your getDistance method is not a UDF; it is a plain Scala method that expects 4 Double arguments, whereas you are passing it 4 Strings.

To fix this, you need to:

  • wrap your method in a UDF, and
  • when applying the UDF, pass Columns rather than Strings, which you can do by prefixing the column names with $:
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession

val distanceUdf: UserDefinedFunction = udf(getDistance _)

joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))