在Spark数据帧udf中,类似struct(col1,col2)的函数参数的类型是什么?

时间:2018-03-10 08:49:37

标签: apache-spark apache-spark-sql apache-spark-dataset

背景:

我有一个包含三列的数据框:id, x, y。 x,y是Double。

  • 首先,我struct (col("x"),col("y"))获取坐标列。
  • 然后groupBy(col("id"))agg(collect_list(col("coordinate")))

所以现在df只有两列:id ,coordinate

我认为坐标的数据类型是collection.mutable.WrappedArray[(Double,Double)]。   所以我把它传给了udf。但是,数据类型是错误的。运行代码时出现错误。我不知道为什么。 struct(col1,col2)的真实数据类型是什么?或者是否有其他方法可以轻松获得正确的答案?

这是代码:

def getMedianPoint = udf((array1: collection.mutable.WrappedArray[(Double,Double)]) => {  
    var l = (array1.length/2)
    var c = array1(l)
    val x = c._1.asInstanceOf[Double]
    val y = c._2.asInstanceOf[Double]
    (x,y)
})

df.withColumn("coordinate",struct(col("x"),col("y")))
  .groupBy(col("id"))
  .agg(collect_list("coordinate").as("coordinate")
  .withColumn("median",getMedianPoint(col("coordinate")))

非常感谢!

1 个答案:

答案 0 :(得分:0)

  
    

我认为坐标的数据类型是collection.mutable.WrappedArray [(Double,Double)]

  

是的,你说是绝对正确的。 您在udf函数中定义为dataTypes的内容以及您作为参数传递的内容也是正确的。但主要问题是struct column的键名称。因为你必须得到以下问题

  
    

无法解析' UDF(坐标)'由于数据类型不匹配:参数1需要数组>但是,类型' coordinate'是数组>类型。;;

  

只需使用alias 将结构键重命名为

即可消除错误
df.withColumn("coordinate",struct(col("x").as("_1"),col("y").as("_2")))
  .groupBy(col("id"))
  .agg(collect_list("coordinate").as("coordinate"))
    .withColumn("median",getMedianPoint(col("coordinate")))

以便键名匹配。

但是

这将在

引起另一个问题
  var c = array1(l)
  
    

引起:java.lang.ClassCastException:org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema无法强制转换为scala.Tuple2

  

因此我建议您将udf功能更改为

import org.apache.spark.sql.functions._

def getMedianPoint = udf((array1: Seq[Row]) => {
  var l = (array1.length/2)
  (array1(l)(0).asInstanceOf[Double], array1(l)(1).asInstanceOf[Double])
})

因此,您甚至不需要使用alias。所以完整的解决方案将是

import org.apache.spark.sql.functions._

def getMedianPoint = udf((array1: Seq[Row]) => {
  var l = (array1.length/2)
  (array1(l)(0).asInstanceOf[Double], array1(l)(1).asInstanceOf[Double])
})

df.withColumn("coordinate",struct(col("x"),col("y")))
  .groupBy(col("id"))
  .agg(collect_list("coordinate").as("coordinate"))
    .withColumn("median",getMedianPoint(col("coordinate")))
  .show(false)

我希望答案很有帮助