Spark SQL does not call my UDT equals/hashCode methods

Date: 2019-05-06 10:52:35

Tags: apache-spark apache-spark-sql user-defined-types

I want to implement comparison operations (equals, hashCode, ordering) for a data type I have defined in Spark SQL. Although Spark SQL's UDT API is still private, I followed a few examples such as this one to work around that.

I have a class named MyPoint:

import org.apache.spark.sql.types.SQLUserDefinedType

@SQLUserDefinedType(udt = classOf[MyPointUDT])
case class MyPoint(x: Double, y: Double) extends Serializable {

  override def hashCode(): Int = {
    println("hash code")
    31 * (31 * x.hashCode()) + y.hashCode()
  }

  override def equals(other: Any): Boolean = {
    println("equals")
    other match {
      case that: MyPoint => this.x == that.x && this.y == that.y
      case _ => false
    }
  }
}
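
On the JVM side these overrides behave as expected. A minimal check (a hypothetical snippet, not from the original post):

// Plain Scala comparison goes through MyPoint.equals, so "equals" is printed.
val a = MyPoint(1.0, 2.0)
val b = MyPoint(1.0, 2.0)
println(a == b) // true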

Then I have the UDT class:

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

private class MyPointUDT extends UserDefinedType[MyPoint] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def serialize(obj: MyPoint): ArrayData = {
    obj match {
      case features: MyPoint =>
        new GenericArrayData(Array(features.x, features.y))
    }
  }

  override def deserialize(datum: Any): MyPoint = {
    datum match {
      case data: ArrayData if data.numElements() == 2 =>
        val arr = data.toDoubleArray()
        new MyPoint(arr(0), arr(1))
    }
  }

  override def userClass: Class[MyPoint] = classOf[MyPoint]

  override def asNullable: MyPointUDT = this
}
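
For context: Spark SQL operators such as DISTINCT compare values in their serialized Catalyst form (the ArrayType(DoubleType) produced by sqlType above), not through the user class's equals/hashCode. A minimal sketch illustrating this, assuming MyPointUDT is accessible from the calling code:

// Two equal points serialize to element-wise identical ArrayData;
// this serialized form is what Spark actually compares.
val udt = new MyPointUDT()
val s1 = udt.serialize(MyPoint(1.0, 2.0))
val s2 = udt.serialize(MyPoint(1.0, 2.0))
println(s1.toDoubleArray().sameElements(s2.toDoubleArray())) // true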

Then I create a simple DataFrame:

import spark.implicits._

val p1 = new MyPoint(1.0, 2.0)
val p2 = new MyPoint(1.0, 2.0)
val p3 = new MyPoint(10.0, 20.0)
val p4 = new MyPoint(11.0, 22.0)

val points = Seq(
  ("P1", p1),
  ("P2", p2),
  ("P3", p3),
  ("P4", p4)
).toDF("label", "point")

// registerTempTable is deprecated since Spark 2.0; use createOrReplaceTempView.
points.createOrReplaceTempView("points")
spark.sql("SELECT DISTINCT point FROM points").show()
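
If the goal is to run MyPoint.equals inside a query, one possible workaround (a hypothetical sketch, not from the original post) is a UDF, since a Scala UDF receives the deserialized user-class objects:

import org.apache.spark.sql.functions.udf

// samePoint is a hypothetical helper: the UDF deserializes each argument
// back to MyPoint, so MyPoint.equals is actually invoked here.
val samePoint = udf((a: MyPoint, b: MyPoint) => a == b)

points.as("l").crossJoin(points.as("r"))
  .where(samePoint($"l.point", $"r.point"))
  .show()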

The question: why does the SQL query not execute the equals method inside the MyPoint class? How is the comparison actually performed, and how can comparison operators be implemented in this example?

0 Answers

No answers yet.