如何映射数据集和数组以生成新行(给出数据集[数组[点]]而不是数据集[点])?

时间:2017-07-20 20:11:00

标签: scala apache-spark apache-spark-sql

我有以下功能,我希望结果的类型为Dataset[Point]Array[Point]。但它返回Dataset[Array[Point]]

另外,我想通过Point.Point >= 8过滤结果。调用过滤函数的最佳位置在哪里?

  def compare2(dbo: Dataset[Cols], ods: Array[Cols]) = {
    import dbo.sparkSession.implicits._
    dbo.mapPartitions(p => p.map(l => ods.map(r =>
      Point(l.Id, r.Id, getPoint(l, r))))
      //.filter(p => p.Point >= 8) // p is Array[Point]
    )
  }

case class Cols (Id: Int, F1: String, F2: String, F3: String)
case class Point (Id1: Int, Id2: Int, Point: Int)

2 个答案:

答案 0 :(得分:2)

dbo.mapPartitions需要func: (Iterator[T]) ⇒ Iterator[U](删除隐式Encoder以使事情更清晰。)

  

mapPartitions [U](func:(Iterator [T])⇒迭代器[U]):数据集[U] 返回一个新的数据集,其中包含将func应用于每个分区的结果。< / p>

这样,p内的mapPartitions类型为Iterator[Cols]

p.map(l提供l类型Cols和类型Iterator[T]的结果。

你正在制作Iterator[Iterator[T]],但这还不够:(

由于ods: Array[Cols] ods.map(r所在的部分Array[Point]提供了dbo.mapPartitions { p: Iterator[Cols] => p.map { l: Cols => ods.map { r: Cols => Point(l.Id, r.Id, getPoint(l, r) } } }

考虑到这一切,你有一个巨大的心理任务来理解这里发生了什么,并可以重写为以下代码:

dbo.mapPartitions { p: Iterator[Cols] =>
  for {
    l <- p
    r <- ods
  } yield Point(l.Id, r.Id, getPoint(l, r))
}

为了让事情变得更容易(尤其是未来的代码读者),我建议您使用Scala的For Comprehension进行另一次重写:

  

Scala提供了一种用于表达序列理解的轻量级符号。

我建议如下:

if

过滤非常简单,需要一个plotly作为理解的一部分。

答案 1 :(得分:0)

您可能想要使用flatMap:

 dbo.mapPartitions(p => p.flatMap(l => ods.map(r =>
   Point(l.Id, r.Id, getPoint(l, r))))
  .filter(p => p.Point >= 8) // p is Array[Point]

您基本上将每个cols转换为一个点数组,因为最后一部分是数组[Cols] =&gt; Array [Points]的映射。如果是flatmap,它将使数据集将这些数组展平为元素。一旦你这样做,过滤器应该正常工作。