我有以下功能,我希望结果的类型为Dataset[Point]
或Array[Point]
。但它返回Dataset[Array[Point]]
。
另外,我想通过Point.Point >= 8
过滤结果。调用过滤函数的最佳位置在哪里?
def compare2(dbo: Dataset[Cols], ods: Array[Cols]) = {
import dbo.sparkSession.implicits._
dbo.mapPartitions(p => p.map(l => ods.map(r =>
Point(l.Id, r.Id, getPoint(l, r))))
//.filter(p => p.Point >= 8) // p is Array[Point]
)
}
case class Cols (Id: Int, F1: String, F2: String, F3: String)
case class Point (Id1: Int, Id2: Int, Point: Int)
答案 0 :(得分:2)
dbo.mapPartitions
需要func: (Iterator[T]) ⇒ Iterator[U]
(删除隐式Encoder
以使事情更清晰。)
mapPartitions [U](func:(Iterator [T])⇒迭代器[U]):数据集[U] 返回一个新的数据集,其中包含将func应用于每个分区的结果。< / p>
这样,p
内的mapPartitions
类型为Iterator[Cols]
。
p.map(l
提供l
类型Cols
和类型Iterator[T]
的结果。
你正在制作Iterator[Iterator[T]]
,但这还不够:(
由于ods: Array[Cols]
ods.map(r
所在的部分Array[Point]
提供了dbo.mapPartitions { p: Iterator[Cols] =>
p.map { l: Cols =>
ods.map { r: Cols =>
Point(l.Id, r.Id, getPoint(l, r) } } }
。
考虑到这一切,你有一个巨大的心理任务来理解这里发生了什么,并可以重写为以下代码:
dbo.mapPartitions { p: Iterator[Cols] =>
for {
l <- p
r <- ods
} yield Point(l.Id, r.Id, getPoint(l, r))
}
为了让事情变得更容易(尤其是未来的代码读者),我建议您使用Scala的For Comprehension进行另一次重写:
Scala提供了一种用于表达序列理解的轻量级符号。
我建议如下:
if
过滤非常简单,需要一个plotly
作为理解的一部分。
答案 1 :(得分:0)
您可能想要使用flatMap:
dbo.mapPartitions(p => p.flatMap(l => ods.map(r =>
Point(l.Id, r.Id, getPoint(l, r))))
.filter(p => p.Point >= 8) // p is Array[Point]
您基本上将每个cols转换为一个点数组,因为最后一部分是数组[Cols] =&gt; Array [Points]的映射。如果是flatmap,它将使数据集将这些数组展平为元素。一旦你这样做,过滤器应该正常工作。