How to convert a Seq to an RDD when using a Spark streaming context

Date: 2016-04-26 09:52:56

Tags: scala apache-spark spark-streaming

I am writing some tests with spark-streaming using TestSuiteBase (which provides a streaming context, ssc). I then create dummy data as output: Seq[Seq[(Double, Double)]]. Finally, I want to apply a function to output, but that function accepts an RDD[(Double, Double)], not a Seq[Seq[(Double, Double)]].

To work around this I plan to use val rdd: RDD[(Double, Double)] = sc.parallelize(output.flatten), but how and where do I get the Spark context sc from ssc? Alternatively, is there a way to create the dummy data directly as an RDD, without going through a Seq?
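As a sketch of the intended conversion (the Seq and its values below are made up for illustration; parallelize itself needs a live SparkContext, e.g. ssc.sparkContext):

```scala
// Dummy batches: each inner Seq is one batch of (label, prediction) pairs.
val output: Seq[Seq[(Double, Double)]] = Seq(
  Seq((1.0, 1.1), (2.0, 1.9)),
  Seq((3.0, 3.2))
)

// Flatten the per-batch structure into a single flat Seq of pairs.
val flat: Seq[(Double, Double)] = output.flatten
println(flat) // List((1.0,1.1), (2.0,1.9), (3.0,3.2))

// With a live context: val rdd = ssc.sparkContext.parallelize(flat)
```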

class StreamingTestLR  extends SparkFunSuite
                       with TestSuiteBase {

  // use longer wait time to ensure job completion
  override def maxWaitTimeMillis: Int = 20000

  var ssc: StreamingContext = _

  override def afterFunction() {
    super.afterFunction()
    if (ssc != null) {
      ssc.stop()
    }
  }

//...

val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)

// THE PROBLEM IS HERE!!!
// val metrics = new SomeFuncThatAcceptsRDD(rdd)

}

Update

  // Test if the prediction accuracy increases when using hyper-parameter optimization
  // in order to learn Y = 10*X1 + 10*X2 on streaming data
  test("Test 1") {
    // create model initialized with zero weights
    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.dense(0.0, 0.0))
      .setStepSize(0.2)
      .setNumIterations(25)

    // generate sequence of simulated data for testing
    val numBatches = 10
    val nPoints = 100
    val testInput = (0 until numBatches).map { i =>
      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
    }
    // setupStreams expects the streaming operation as a function of the input DStream
    withStreamingContext(setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
      model.trainOn(inputDStream)
      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
    })) { ssc =>
      val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)


      val rdd: RDD[(Double, Double)] = ssc.sparkContext.parallelize(output.flatten)

      // Instantiate metrics object
      val metrics = new RegressionMetrics(rdd)

      // Squared error
      println(s"MSE = ${metrics.meanSquaredError}")
      println(s"RMSE = ${metrics.rootMeanSquaredError}")

      // R-squared
      println(s"R-squared = ${metrics.r2}")

      // Mean absolute error
      println(s"MAE = ${metrics.meanAbsoluteError}")

      // Explained variance
      println(s"Explained variance = ${metrics.explainedVariance}")
    }
  }

1 Answer:

Answer 0 (score: 3)

Try this:

class MyTestSuite extends TestSuiteBase with BeforeAndAfter {

  test("my test") {
    withTestServer(new TestServer()) { testServer =>
      // Start the server
      testServer.start()
      // Set up the streaming context and input streams
      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
        // `output` stands for the Seq[Seq[(Double, Double)]] returned by runStreams
        val rdd = ssc.sparkContext.parallelize(output.flatten)
        // your code here
        testServer.stop()
        ssc.stop()
      }
    }
  }
}

More details here: https://github.com/apache/spark/blob/master/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
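Regarding the second part of the question: if TestSuiteBase's input plumbing is not needed at all, the dummy pairs can also be built directly as an RDD from a plain local SparkContext. A minimal sketch (the app name and data values are illustrative, and this assumes Spark is on the classpath):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

val conf = new SparkConf().setMaster("local[2]").setAppName("rdd-dummy-data")
val sc = new SparkContext(conf)

// Build (label, prediction) pairs directly as an RDD, skipping the
// Seq[Seq[(Double, Double)]] intermediate entirely.
val dummy: RDD[(Double, Double)] = sc.parallelize(
  (1 to 100).map(i => (i.toDouble, i.toDouble))
)

sc.stop()
```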