避免两次指定架构(Spark / scala)

时间:2018-11-08 11:10:46

标签: scala apache-spark apache-spark-sql apache-spark-dataset

我需要按特定顺序遍历数据帧,并应用一些复杂的逻辑来计算新列。

我的强烈偏爱是以通用方式执行此操作,因此我不必列出一行的所有列并进行df.as[my_record]case Row(...) =>的操作,如图here所示。相反,我想按行名称访问行列,只需将结果列添加到源行。

下面的方法很好用,但是我想避免两次指定架构:第一次,这样我可以在迭代时按名称访问列,第二次来处理输出。

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

val q = """
select 2 part, 1 id
union all select 2 part, 4 id
union all select 2 part, 3 id
union all select 2 part, 2 id
"""
val df = spark.sql(q)

def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  if (iter.hasNext) {
    def complex_logic(p: Int): Integer = if (p == 3) null else p * 10;

    val head = iter.next
    val schema = StructType(head.schema.fields :+ StructField("result", IntegerType))
    val r =
      new GenericRowWithSchema((head.toSeq :+ complex_logic(head.getAs("id"))).toArray, schema)

    iter.scanLeft(r)((r1, r2) =>
      new GenericRowWithSchema((r2.toSeq :+ complex_logic(r2.getAs("id"))).toArray, schema)
    )
  } else iter
}

val schema = StructType(df.schema.fields :+ StructField("result", IntegerType))
val encoder = RowEncoder(schema)
df.repartition($"part").sortWithinPartitions($"id").mapPartitions(f_row)(encoder).show

应用mapPartitions后丢失了哪些信息,因此如果没有显式编码器就无法处理输出?如何避免指定它?

3 个答案:

答案 0 :(得分:0)

  

在应用mapPartitions之后丢失了哪些信息,因此如果没有这些信息,就无法处理输出

信息几乎不会丢失-从一开始就不存在-RowInternalRow的子类基本上是无类型的可变形状容器,它们不提供任何有用的类型信息,可以用来导出Encoder

schema中的

GenericRowWithSchema无关紧要,因为它是根据元数据而非类型来描述内容的。

  

如何避免指定它?

对不起,您不走运。如果要以静态类型的语言使用动态类型的构造(一袋Any),则必须付出代价,这里提供了Encoder

答案 1 :(得分:0)

好的-我已经检查了一些火花代码,并且将.mapPartitions与Dataset API结合使用不需要我显式构建/传递编码器。

您需要以下内容:

case class Before(part: Int, id: Int)
case class After(part: Int, id: Int, newCol: String)

import spark.implicits._

// Note column names/types must match case class constructor parameters.
val beforeDS = <however you obtain your input DF>.as[Before]

def f_row(it: Iterator[Before]): Iterator[After] = ???

beforeDS.reparition($"part").sortWithinPartitions($"id").mapPartitions(f_row).show

答案 2 :(得分:0)

我发现以下解释足够了,也许对其他人有用。

mapPartitions需要Encoder,因为否则它不能从迭代器或Dataset构造Row。即使每行都有一个架构,但Dataset[U]的构造函数无法导出(使用)该shema。

  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
    new Dataset[U](
      sparkSession,
      MapPartitions[T, U](func, logicalPlan),
      implicitly[Encoder[U]])
  }

另一方面,由于不更改原始列的结构(元数据),因此无需调用mapPartitions,Spark可以使用从初始查询派生的架构。

我在此答案中描述了替代方法:https://stackoverflow.com/a/53177628/7869491