Saving Spark DataFrames with nested user data types

Date: 2015-09-17 08:10:22

Tags: apache-spark apache-spark-sql

I want to save (as a Parquet file) a Spark DataFrame that contains a custom class as a column. This class is composed of a Seq of another custom class. To do so, I create a UserDefinedType for each of these classes, in a similar way to VectorUDT. I can work with the DataFrame as I intended, but I cannot save it to disk as Parquet (or JSON). I reported it as a bug, but perhaps there is a problem with my code. I have implemented a simpler example to show the problem:

import org.apache.spark.sql.SaveMode
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData // needed for GenericArrayData below; this is its package in Spark 1.6, adjust for other versions
import org.apache.spark.sql.types._

@SQLUserDefinedType(udt = classOf[AUDT])
case class A(list:Seq[B])

class AUDT extends UserDefinedType[A] {
  override def sqlType: DataType = StructType(Seq(StructField("list", ArrayType(BUDT, containsNull = false), nullable = true)))
  override def userClass: Class[A] = classOf[A]
  override def serialize(obj: Any): Any = obj match {
    case A(list) =>
      val row = new GenericMutableRow(1)
      row.update(0, new GenericArrayData(list.map(_.asInstanceOf[Any]).toArray))
      row
  }

  override def deserialize(datum: Any): A = {
    datum match {
      case row: InternalRow => new A(row.getArray(0).toArray(BUDT).toSeq)
    }
  }
}

object AUDT extends AUDT

@SQLUserDefinedType(udt = classOf[BUDT])
case class B(num:Int)

class BUDT extends UserDefinedType[B] {
  override def sqlType: DataType = StructType(Seq(StructField("num", IntegerType, nullable = false)))
  override def userClass: Class[B] = classOf[B]
  override def serialize(obj: Any): Any = obj match {
    case B(num) =>
      val row = new GenericMutableRow(1)
      row.setInt(0, num)
      row
  }

  override def deserialize(datum: Any): B = {
    datum match {
      case row: InternalRow => new B(row.getInt(0))
    }
  }
}

object BUDT extends BUDT

object TestNested {
  def main(args:Array[String]) = {
    val col = Seq(new A(Seq(new B(1), new B(2))),
                  new A(Seq(new B(3), new B(4))))

    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("TestSpark"))
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(1 to 2 zip col).toDF()
    df.show()

    df.write.mode(SaveMode.Overwrite).save(...)
  }
}

This results in the following error:

  

15/09/16 16:44:39 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1)
java.lang.IllegalArgumentException: Nested type should be repeated: required group array {required int32 num;}
    at org.apache.parquet.schema.ConversionPatterns.listWrapper(ConversionPatterns.java:42)
    at org.apache.parquet.schema.ConversionPatterns.listType(ConversionPatterns.java:97)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convertField(CatalystSchemaConverter.scala:460)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convertField(CatalystSchemaConverter.scala:318)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter$$anonfun$convertField$1.apply(CatalystSchemaConverter.scala:522)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter$$anonfun$convertField$1.apply(CatalystSchemaConverter.scala:521)
    at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:51)
    at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:60)
    at scala.collection.mutable.ArrayOps$ofRef.foldLeft(ArrayOps.scala:108)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convertField(CatalystSchemaConverter.scala:521)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convertField(CatalystSchemaConverter.scala:318)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convertField(CatalystSchemaConverter.scala:526)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convertField(CatalystSchemaConverter.scala:318)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter$$anonfun$convert$1.apply(CatalystSchemaConverter.scala:311)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter$$anonfun$convert$1.apply(CatalystSchemaConverter.scala:311)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
    at org.apache.spark.sql.types.StructType.foreach(StructType.scala:92)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at org.apache.spark.sql.types.StructType.map(StructType.scala:92)
    at org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.convert(CatalystSchemaConverter.scala:311)
    at org.apache.spark.sql.execution.datasources.parquet.ParquetTypesConverter$.convertFromAttributes(ParquetTypesConverter.scala:58)
    at org.apache.spark.sql.execution.datasources.parquet.RowWriteSupport.init(ParquetTableSupport.scala:55)
    at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:288)
    at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:262)
    at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.<init>(ParquetRelation.scala:94)
    at org.apache.spark.sql.execution.datasources.parquet.ParquetRelation$$anon$3.newInstance(ParquetRelation.scala:272)
    at org.apache.spark.sql.execution.datasources.DefaultWriterContainer.writeRows(WriterContainer.scala:234)
    at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelation$$anonfun$run$1$$anonfun$apply$mcV$sp$3.apply(InsertIntoHadoopFsRelation.scala:150)
    at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelation$$anonfun$run$1$$anonfun$apply$mcV$sp$3.apply(InsertIntoHadoopFsRelation.scala:150)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
    at org.apache.spark.scheduler.Task.run(Task.scala:88)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)
15/09/16 16:44:39 WARN TaskSetManager: Lost task 0.0 in stage 1.0 (TID 1, localhost):

Saving a DataFrame with B instead of A causes no problem, since B does not contain a nested custom class; a minimal sketch of that working case follows below. Am I missing something?
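
A minimal sketch of the B-only case that saves without error (same SparkContext/SQLContext setup and the B/BUDT definitions from above; the output path is just a placeholder I picked for illustration):

val colB = Seq(new B(1), new B(2))
val dfB = sc.parallelize(1 to 2 zip colB).toDF()
dfB.show()
// a column of B (no nested custom class) writes to Parquet without error
dfB.write.mode(SaveMode.Overwrite).parquet("/tmp/b_only.parquet")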

1 answer:

Answer 0 (score: 2):

I had to make four changes to your code to get it to work (tested on Spark 1.6.0 on Linux), and I think I can mostly explain why they are needed. I do find myself wondering, however, whether there is a simpler solution. All the changes are in AUDT, as follows:

  1. When defining sqlType, make it depend on BUDT.sqlType, rather than just BUDT.
  2. In serialize(), call BUDT.serialize() on each list element.
  3. In deserialize():
    • call toArray(BUDT.sqlType) rather than toArray(BUDT)
    • call BUDT.deserialize() on each element

Here is the resulting code:

    class AUDT extends UserDefinedType[A] {
      override def sqlType: DataType =
        StructType(
          Seq(StructField("list",
                          ArrayType(BUDT.sqlType, containsNull = false),
                          nullable = true)))
    
      override def userClass: Class[A] = classOf[A]
    
      override def serialize(obj: Any): Any = 
        obj match {
          case A(list) =>
            val row = new GenericMutableRow(1)
            val elements =
              list.map(_.asInstanceOf[Any])
                  .map(e => BUDT.serialize(e))
                  .toArray
            row.update(0, new GenericArrayData(elements))
            row
        }
    
      override def deserialize(datum: Any): A = {
        datum match {
          case row: InternalRow => 
            val first = row.getArray(0)
            val bs:Array[InternalRow] = first.toArray(BUDT.sqlType)
            val bseq = bs.toSeq.map(e => BUDT.deserialize(e))
            val a = new A(bseq)
            a
        }
      }
    
    }
    

All four changes have the same character: the relationship between handling A and handling B is now made explicit, for the schema type, for serialization, and for deserialization. The original code seems to rest on the assumption that Spark SQL will "just figure it out", which may be reasonable, but apparently it doesn't.
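
For reference, a minimal round-trip sketch using the corrected AUDT (re-using col, sc and sqlContext from the question's TestNested example; the path is just a placeholder):

    val df = sc.parallelize(1 to 2 zip col).toDF()
    df.write.mode(SaveMode.Overwrite).parquet("/tmp/nested_udt.parquet")
    // read the file back to check that the nested values survive the round trip
    val reloaded = sqlContext.read.parquet("/tmp/nested_udt.parquet")
    reloaded.show()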