遍历Scala列的元素

时间:2018-08-28 02:38:39

标签: scala apache-spark apache-spark-sql

我有一个由两个Doubles数组组成的数据框。我想创建一个新列,该列是对前两列应用欧几里德距离函数的结果,即如果我有:

 A      B 
(1,2)  (1,3)
(2,3)  (3,4)

创建:

 A      B     C
(1,2)  (1,3)  1
(2,3)  (3,4)  1.4

我的数据模式是:

df.schema.foreach(println)
StructField(col1,ArrayType(DoubleType,false),false)
StructField(col2,ArrayType(DoubleType,false),true)

每当我调用此距离函数时:

def distance(xs: Array[Double], ys: Array[Double]) = {
  sqrt((xs zip ys).map { case (x,y) => pow(y - x, 2) }.sum)
}

我收到类型错误:

df.withColumn("distances" , distance($"col1",$"col2"))
<console>:68: error: type mismatch;
 found   : org.apache.spark.sql.ColumnName
 required: Array[Double]
       ids_with_predictions_centroids3.withColumn("distances" , distance($"col1",$"col2"))

我知道我必须遍历每一列的元素,但是我无法在任何地方找到有关如何执行此操作的说明。我是Scala编程的新手。

2 个答案:

答案 0 :(得分:4)

要在数据框上使用自定义函数,您需要将其定义为UDF。例如,可以这样做,如下所示:

val distance = udf((xs: WrappedArray[Double], ys: WrappedArray[Double]) => {
  math.sqrt((xs zip ys).map { case (x,y) => math.pow(y - x, 2) }.sum)
})

df.withColumn("C", distance($"A", $"B")).show()

请注意,此处需要使用WrappedArray(或Seq)。

结果数据框:

+----------+----------+------------------+
|         A|         B|                 C|
+----------+----------+------------------+
|[1.0, 2.0]|[1.0, 3.0]|               1.0|
|[2.0, 3.0]|[3.0, 4.0]|1.4142135623730951|
+----------+----------+------------------+

答案 1 :(得分:3)

Spark函数在基于列的上工作,并且您唯一的错误是您在函数中混合了列和基元

错误消息非常清楚,表明您正在 distance 函数中传递一列,即$"col1"$"col2"但是 distance 函数定义为{<1>}并采用原始类型

解决方案是使距离函数完全基于列

distance(xs: Array[Double], ys: Array[Double])

应该可以为您提供正确的结果而不会出错

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

def distance(xs: Column, ys: Column) = {
  sqrt(pow(ys(0)-xs(0), 2) + pow(ys(1)-xs(1), 2))
}

df.withColumn("distances" , distance($"col1",$"col2")).show(false)

我希望答案会有所帮助