如何使用Scala Spark在RDD中获取多个相邻数据

时间:2017-06-03 06:04:23

标签: scala apache-spark apache-spark-sql

我有一个rddRDD的值为0或1,限制为4.当我映射rdd时,如果input : 1,0,0,0,0,0,1,0,0 expected output : 1,1,1,1,0,0,1,1,1 为值为1则从当前位置到(当前位置+限制)的值都是1,否则有0 0。 示例

val rdd = sc.parallelize(Array(1, 0, 0, 0, 0, 0, 1, 0, 0))
val limit = 4
val resultlimit = rdd.mapPartitions(parIter => {
  var result = new ArrayBuffer[Int]()
  var resultIter = new ArrayBuffer[Int]()
  while (parIter.hasNext) {
    val iter = parIter.next()
    resultIter.append(iter)
  }
  var i = 0
  while (i < resultIter.length) {
    result.append(resultIter(i))
    if (resultIter(i) == 1) {
      var j = 1
      while (j + i < resultIter.length && j < limit) {
        result.append(1)
        j += 1
      }
      i += j
    } else {
      i += 1
    }
  }
  result.toIterator
})
resultlimit.foreach(println)

这是我到目前为止所尝试的:

RDD:[1,1,1,1,0,0,1,1,1]

resultlimit的结果是{{1}}

我快速而肮脏的方法是首先创建一个数组,但这是如此丑陋和低效。

有没有更清洁的解决方案?

2 个答案:

答案 0 :(得分:1)

简单明了。导入RDDFunctions

import org.apache.spark.mllib.rdd.RDDFunctions._

定义限制:

val limit: Int = 4

limit - 1个零填充到第一个分区:

val extended = rdd.mapPartitionsWithIndex {
  case (0, iter) => Seq.fill(limit - 1)(0).toIterator ++ iter
  case (_, iter) => iter
}

滑过RDD

val result = extended.sliding(limit).map {
  slice => if (slice.exists(_ != 0)) 1 else 0
}

检查结果:

val expected = Seq(1,1,1,1,0,0,1,1,1)
require(expected == result.collect.toSeq)

另一方面,您当前的方法无法纠正分区边界,因此结果将根据来源而有所不同。

答案 1 :(得分:0)

以下是针对您的要求的改进方法。三个while循环减少为一个for循环,两个ArrayBuffer减少为一个ArrayBuffer。因此减少了处理时间和内存使用量。

val resultlimit= rdd.mapPartitions(parIter => {
  var result = new ArrayBuffer[Int]()
  var limit = 0
  for (value <- parIter) {
    if (value == 1) limit = 4
    if (limit > 0) {
      result.append(1)
      limit -= 1
    }
    else {
      result.append(value)
    }
  }
  result.toIterator
})

<强>被修改

以上解决方案是指您未在原始partition中定义rdd。但是当分区定义为

val rdd = sc.parallelize(Array(1,1,0,0,0,0,1,0,0), 4)

我们需要collect rdds上述解决方案将在每个partitions上执行。

所以以下解决方案应该有效

 var result = new ArrayBuffer[Int]()
  var limit = 0
  for (value <- rdd.collect()) {
    if (value == 1) limit = 4
    if (limit > 0) {
      result.append(1)
      limit -= 1
    }
    else {
      result.append(value)
    }
  }
result.foreach(println)