How to implement Seq.grouped(size: Int): Seq[Seq[A]] for a Dataset in Spark

Date: 2019-12-27 12:09:00

Tags: scala apache-spark

I want to implement something like def grouped(size: Int): Iterator[Repr], which Seq has, but for a Dataset in Spark.

So the input should be ds: Dataset[A], size: Int, and the output should be Seq[Dataset[A]], where each Dataset[A] in the output contains no more than size elements.
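In other words, something with this shape (the name grouped below is only for illustration):

import org.apache.spark.sql.Dataset

// Illustrative stub of the signature I'm after, mirroring Seq#grouped.
def grouped[A](ds: Dataset[A], size: Int): Seq[Dataset[A]] = ???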

How should I go about this? I tried using repartition and mapPartitions, but I'm not sure where to go from there.

Thanks.

Edit: I found the glom method on RDD, but it produces an RDD[Array[A]]. How do I get from that to the Seq[Dataset[A]] I want?
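One possible bridge from that RDD[Array[A]] back to a Seq[Dataset[A]] could be to rebuild a small Dataset per glommed partition on the driver. This is only an untested sketch: groupedViaGlom is a made-up name, and collecting every group through the driver may not scale.

import org.apache.spark.sql.{Dataset, Encoder}

// Untested sketch: bound the partition sizes, glom each partition into an
// Array[A], then rebuild one small Dataset per array on the driver.
def groupedViaGlom[A: Encoder](ds: Dataset[A], size: Int): Seq[Dataset[A]] = {
  val spark = ds.sparkSession
  val count = ds.count()
  val numPartitions = math.max(1, math.ceil(count.toDouble / size).toInt)
  ds.repartition(numPartitions)   // round-robin, so partitions come out roughly evenly sized
    .rdd
    .glom()                       // RDD[Array[A]]: one array per partition
    .collect()                    // Array[Array[A]] pulled to the driver
    .toSeq
    .map(arr => spark.createDataset(arr.toSeq))
}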

2 Answers:

Answer 0 (score: 1):

Here you go:

/*
{"countries":"pp1"}
{"countries":"pp2"}
{"countries":"pp3"}
{"countries":"pp4"}
{"countries":"pp5"}
{"countries":"pp6"}
{"countries":"pp7"}
   */

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.{SparkConf, SparkContext}


object SparkApp {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("Simple Application").setMaster("local").set("spark.ui.enabled", "false")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Read the sample records shown above.
    val dataFrame: DataFrame = sqlContext.read.json("/data.json")

    // Desired group size.
    val k = 3

    // Put every row into a single window partition and number the rows,
    // so they can be sliced into consecutive chunks of k.
    val windowSpec = Window.partitionBy("grouped").orderBy("countries")

    val newDF = dataFrame.withColumn("grouped", lit("grouping"))

    val latestDF = newDF.withColumn("row", row_number() over windowSpec)

    val totalCount = latestDF.count()
    var lowLimit = 0
    var highLimit = lowLimit + k

    // Show each chunk of at most k rows.
    while (lowLimit < totalCount) {
      latestDF.where(s"row <= $highLimit and row > $lowLimit").show(false)
      lowLimit = lowLimit + k
      highLimit = highLimit + k
    }
  }
}
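The loop above only show()s each chunk. If the goal is to hand the chunks back as a collection, the same row_number bookkeeping can be wrapped in a function. The sketch below is an addition on top of this answer, not part of it: groupedByRowNumber is a made-up name, and the countries/grouped column names follow the sample data above.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Sketch: number the rows once, then slice them into consecutive chunks of at most `size` rows.
def groupedByRowNumber(df: DataFrame, size: Int): Seq[DataFrame] = {
  val windowSpec = Window.partitionBy("grouped").orderBy("countries")
  val numbered = df
    .withColumn("grouped", lit("grouping"))
    .withColumn("row", row_number() over windowSpec)
  val total = numbered.count()

  // One DataFrame per slice; drop the helper columns again on the way out.
  (0L until total by size.toLong).map { low =>
    numbered
      .where(col("row") > low && col("row") <= low + size)
      .drop("row", "grouped")
  }
}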

Answer 1 (score: 0):

This is the solution I found, but I'm not sure whether it works reliably:

  override protected def batch[A](
    input:     Dataset[A],
    batchSize: Int
  ): Seq[Dataset[A]] = {
    val count = input.count()
    // Divide as Doubles: with integer division the truncation happens before
    // Math.ceil, so too few splits are produced and batches can exceed batchSize.
    val partitionQuantity = Math.ceil(count.toDouble / batchSize).toInt

    // Note: randomSplit only weights the splits randomly; it does not hard-cap
    // each resulting Dataset at batchSize rows.
    input.randomSplit(Array.fill(partitionQuantity)(1.0 / partitionQuantity), seed = 0)
  }