Exception when using a UDT in a Spark DataFrame

Asked: 2015-05-15 13:01:21

Tags: apache-spark apache-spark-sql

I am trying to create a user-defined type in Spark SQL, but even when following their example I get: com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType. Has anyone gotten this to work?

My code:

test("udt serialisation") {
    val points = Seq(new ExamplePoint(1.3, 1.6), new ExamplePoint(1.3, 1.8))
    val df = SparkContextForStdout.context.parallelize(points).toDF()
}

import java.util

import scala.collection.JavaConverters._
import org.apache.spark.sql.types._

@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
case class ExamplePoint(val x: Double, val y: Double)

/**
 * User-defined type for [[ExamplePoint]].
 */
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

  // Underlying catalyst type: a point is stored as an array of two doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

  // Convert an ExamplePoint into its catalyst representation.
  override def serialize(obj: Any): Seq[Double] = {
    obj match {
      case p: ExamplePoint =>
        Seq(p.x, p.y)
    }
  }

  // Convert the catalyst value (a Seq or a java.util.ArrayList of
  // doubles) back into an ExamplePoint.
  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
      case values: Seq[_] =>
        val xy = values.asInstanceOf[Seq[Double]]
        assert(xy.length == 2)
        new ExamplePoint(xy(0), xy(1))
      case values: util.ArrayList[_] =>
        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
        new ExamplePoint(xy(0), xy(1))
    }
  }

  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]

}

The relevant part of the stack trace is:

com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
java.lang.ClassCastException: com.ubs.ged.risk.stdout.spark.ExamplePointUDT cannot be cast to org.apache.spark.sql.types.StructType
    at org.apache.spark.sql.SQLContext.createDataFrame(SQLContext.scala:316)
    at org.apache.spark.sql.SQLContext$implicits$.rddToDataFrameHolder(SQLContext.scala:254)

1 Answer:

Answer 0 (score: 1):

It seems that a UDT needs to be used inside another class (as the type of a field): createDataFrame expects the schema derived from the element type to be a StructType, and a bare UDT is not one, hence the ClassCastException. One workaround for using it directly is to wrap it in a Tuple1:

  test("udt serialisation") {
    val points = Seq(new Tuple1(new ExamplePoint(1.3, 1.6)), new Tuple1(new ExamplePoint(1.3, 1.8)))
    val df = SparkContextForStdout.context.parallelize(points).toDF()
    df.collect().foreach(println(_))
  }
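
The same idea works with a named wrapper case class, which additionally gives the column a readable name. A minimal sketch, assuming the same SparkContextForStdout test harness from the question; PointRow is a hypothetical wrapper introduced here for illustration:

  // PointRow is a hypothetical wrapper: the UDT is resolved because it is
  // the type of a field inside a case class, mirroring the Tuple1 trick.
  // Define it at the top level of the suite, not inside the test method.
  case class PointRow(point: ExamplePoint)

  test("udt serialisation via wrapper class") {
    val points = Seq(PointRow(new ExamplePoint(1.3, 1.6)), PointRow(new ExamplePoint(1.3, 1.8)))
    val df = SparkContextForStdout.context.parallelize(points).toDF()
    df.printSchema() // the "point" column should be typed by ExamplePointUDT
  }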