我想为自己在Spark SQL中定义的数据类型实现比较运算(相等性、哈希码、排序)。尽管Spark SQL的UDT(用户自定义类型)API仍然是私有的,我还是参照了一些示例(例如原文链接所指的示例)来绕过这一限制。
我有一个名为MyPoint的类:
/**
 * A two-dimensional point, mapped to Spark SQL through `MyPointUDT`.
 *
 * `equals` and `hashCode` are overridden only to print a trace line, so that any
 * call made to them is visible on the console.
 *
 * @param x the first coordinate
 * @param y the second coordinate
 */
@SQLUserDefinedType(udt = classOf[MyPointUDT])
case class MyPoint(x: Double, y: Double) extends Serializable {

  /**
   * Hash code that agrees with [[equals]].
   *
   * The coordinates are hashed with `##` instead of `hashCode()`: `equals` compares
   * them with `==`, under which `0.0 == -0.0`, whereas `Double.hashCode()` yields
   * different values for `0.0` and `-0.0`. `##` is consistent with `==`, so equal
   * points always share a hash code.
   */
  override def hashCode(): Int = {
    println("hash code")
    31 * (31 * x.##) + y.##
  }

  /**
   * Two points are equal when both coordinates are equal under `==`.
   * Anything that is not a `MyPoint` is never equal to a point.
   */
  override def equals(other: Any): Boolean = {
    println("equals")
    other match {
      case that: MyPoint => this.x == that.x && this.y == that.y
      case _ => false
    }
  }
}
然后,我有UDT类:
/**
 * User-defined type describing how a `MyPoint` is stored: as an array of two
 * doubles laid out as `[x, y]`.
 */
private class MyPointUDT extends UserDefinedType[MyPoint] {

  /** Storage type of a point: an array of doubles that contains no nulls. */
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  /** Packs the coordinates of `obj` into a two-element array, x first and then y. */
  override def serialize(obj: MyPoint): ArrayData = obj match {
    // NOTE(review): GenericArrayData2 is not defined in this snippet - confirm where it comes from.
    case point: MyPoint => new GenericArrayData2(Array(point.x, point.y))
  }

  /**
   * Rebuilds a point from its array form.
   *
   * Only an `ArrayData` holding exactly two elements is matched; any other input
   * falls through the match and fails with a `MatchError`.
   */
  override def deserialize(datum: Any): MyPoint = datum match {
    case array: ArrayData if array.numElements() == 2 =>
      val coordinates = array.toDoubleArray()
      MyPoint(coordinates(0), coordinates(1))
  }

  /** The class this type maps to. */
  override def userClass: Class[MyPoint] = classOf[MyPoint]

  /** Returns this same instance. */
  override def asNullable: MyPointUDT = this
}
然后我创建一个简单的DataFrame:
// Sample points: p1 and p2 carry identical coordinates, p3 and p4 are distinct.
val p1 = new MyPoint(1.0, 2.0)
val p2 = new MyPoint(1.0, 2.0)
val p3 = new MyPoint(10.0, 20.0)
val p4 = new MyPoint(11.0, 22.0)

// One row per point, with a label column next to the point column.
val points = Seq(
  ("P1", p1),
  ("P2", p2),
  ("P3", p3),
  ("P4", p4)
).toDF("label", "point")

// createOrReplaceTempView replaces registerTempTable, which is deprecated since Spark 2.0.
points.createOrReplaceTempView("points")

// Select the distinct points and print the result.
spark.sql("SELECT Distinct(point) FROM points").show()
问题是:为什么SQL查询没有调用MyPoint类中定义的equals方法?Spark SQL是如何对该类型的值进行比较的?在这个示例中,应该如何实现比较运算符?