RDD / DataFrame 的可迭代视图

时间:2015-11-22 09:44:57

标签: scala apache-spark legacy

在迁移过程中,我们必须包装遗留代码。我们正在寻找一种方法,能够遍历 RDD 和 DataFrame 中的所有元素,而无需先用 collect() 把它们复制到一个很大的数组中。

下面的示例中,遗留函数需要一个包含全部数据的 Iterable 作为参数。

import org.apache.spark._
import org.apache.spark.mllib.linalg.{Vector,Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.junit.Test
import org.scalatest.junit.AssertionsForJUnit

/** Minimal reproduction: legacy code expects an `Iterable` over all the data,
  * but the data lives in an RDD / DataFrame.
  */
class SparkLegacy extends AssertionsForJUnit {

  /** Stand-in for the legacy code: consumes every element through the
    * `Iterable` interface and prints it.
    *
    * @param data the vectors to print, in iteration order
    */
  def legacyFunction(data: Iterable[Vector]): Unit = {
    data.foreach(println(_))
  }

  /** Creates a local SparkContext, runs `body` with it and always stops the
    * context afterwards. Spark allows only one active SparkContext per JVM, so a
    * context left running by one test would make the next test fail.
    */
  private def withSparkContext[A](body: SparkContext => A): A = {
    val sparkConf = new SparkConf().setMaster("local").setAppName("Example")
    val sc = new SparkContext(sparkConf)
    try body(sc)
    finally sc.stop()
  }

  @Test def testRDD(): Unit = withSparkContext { sc =>
    val data = Array(Vectors.dense(0.0, 1.1, 0.1), Vectors.dense(2.0, 1.0, -1.0))
    val rdd: RDD[Vector] = sc.parallelize(data, 2)

    // this is possible, but may create a huge temporary collection
    // I would prefer to pass an Iterable view of the RDD to the legacy function
    legacyFunction(rdd.collect())
  }

  @Test def testDataframe(): Unit = withSparkContext { sc =>
    val sqlContext = new SQLContext(sc)

    val data = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5)))).toDF("label", "features")

    // the column whose values the legacy function should iterate over
    val col = data.col("features")
    // I have no idea how to efficiently call the legacy function
  }

}

1 个答案:

答案 0 :(得分:0)

基于 David Maust 的评论:

import org.apache.spark._
import org.apache.spark.mllib.linalg.{ Vector, Vectors }
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.junit.Test
import org.scalatest.junit.AssertionsForJUnit

class SparkLegacy extends AssertionsForJUnit {

  /** Legacy function to call: consumes every element through the `Iterable`
    * interface and prints it.
    *
    * @param data the vectors to print, in iteration order
    */
  def legacyFunction(data: Iterable[Vector]): Unit = {
    data.foreach(println(_))
  }

  /** Enriches an RDD with the `Iterable` interface so it can be handed to legacy code.
    *
    * `iterator` delegates to `RDD.toLocalIterator`, which streams the RDD to the
    * driver one partition at a time. Driver memory is therefore bounded by the
    * largest partition rather than by the whole data set. The trade-off is that
    * one Spark job is run per partition. Every traversal (`foreach`, `size`,
    * `isEmpty`, ...) also recomputes the RDD unless it is cached. Cache the RDD
    * first if its lineage is expensive or the legacy code iterates more than once.
    */
  implicit class RDDIterable[T](rdd: RDD[T]) extends Iterable[T] {
    override def iterator: Iterator[T] = rdd.toLocalIterator

    // Iterable's default toString renders every element. That would silently
    // launch Spark jobs and pull the whole RDD to the driver (e.g. in a log line
    // or a debugger), so describe the wrapped RDD instead.
    override def toString: String = s"RDDIterable($rdd)"
  }

  /** Creates a local SparkContext, runs `body` with it and always stops the
    * context afterwards. Spark allows only one active SparkContext per JVM, so a
    * context left running by one test would make the next test fail.
    */
  private def withSparkContext[A](body: SparkContext => A): A = {
    val sparkConf = new SparkConf().setMaster("local").setAppName("Example")
    val sc = new SparkContext(sparkConf)
    try body(sc)
    finally sc.stop()
  }

  @Test def testRDD(): Unit = withSparkContext { sc =>
    val data = Array(Vectors.dense(0.0, 1.1, 0.1), Vectors.dense(2.0, 1.0, -1.0))
    val vectors: RDD[Vector] = sc.parallelize(data, 2)

    // the view yields exactly the parallelized vectors, in their original order
    assert(new RDDIterable(vectors).toSeq == data.toSeq)

    // call legacy function with implicit Iterable
    legacyFunction(vectors)
  }

  @Test def testDataframe(): Unit = withSparkContext { sc =>
    val sqlContext = new SQLContext(sc)

    val data = sqlContext.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5)))).toDF("label", "features")

    // Convert to an RDD of vectors. Going through `.rdd` keeps the result an RDD
    // on every Spark version. From Spark 2.0 on, `DataFrame.map` returns a
    // Dataset and needs an implicit Encoder[Vector], which is not in scope here.
    val features: RDD[Vector] =
      data.select("features").rdd.map { case Row(v: Vector) => v }

    assert(new RDDIterable(features).size == 4)

    // call legacy function with implicit Iterable
    legacyFunction(features)
  }

}

不过,我是 Spark 新手,不清楚这种做法的效率如何,因此欢迎专家审阅或改进这个答案。