我想尝试实施
def grouped(size: Int): Iterator[Repr]
拥有Seq
,但在Spark中只有Dataset
。
因此输入应为ds: Dataset[A], size: Int
,输出应为Seq[Dataset[A]]
,其中输出中的每个Dataset[A]
都不能大于size
。
我应该如何进行?我尝试使用repartition
和mapPartitions
,但不确定从那里去哪里。
谢谢。
编辑:我在glom
中找到了RDD
方法,但它产生了一个RDD[Array[A]]
,我该如何从{{1 }}?
答案 0 :(得分:1)
您要去的地方
/*
{"countries":"pp1"}
{"countries":"pp2"}
{"countries":"pp3"}
{"countries":"pp4"}
{"countries":"pp5"}
{"countries":"pp6"}
{"countries":"pp7"}
*/
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.{SparkConf, SparkContext};
object SparkApp extends App {
override def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("Simple Application").setMaster("local").set("spark.ui.enabled", "false")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
val dataFrame: DataFrame = sqlContext.read.json("/data.json")
val k = 3
val windowSpec = Window.partitionBy("grouped").orderBy("countries")
val newDF = dataFrame.withColumn("grouped", lit("grouping"))
var latestDF = newDF.withColumn("row", row_number() over windowSpec)
val totalCount = latestDF.count()
var lowLimit = 0
var highLimit = lowLimit + k
while(lowLimit < totalCount){
latestDF.where(s"row <= $highLimit and row > $lowLimit").show(false)
lowLimit = lowLimit + k
highLimit = highLimit + k
}
}
}
答案 1 :(得分:0)
这是我找到的解决方案,但不确定是否可以可靠地工作:
override protected def batch[A](
input: Dataset[A],
batchSize: Int
): Seq[Dataset[A]] = {
val count = input.count()
val partitionQuantity = Math.ceil(count / batchSize).toInt
input.randomSplit(Array.fill(partitionQuantity)(1.0 / partitionQuantity), seed = 0)
}