Converting Scala OOP code to functional code

Date: 2018-07-30 13:14:48

Tags: scala apache-spark functional-programming

I'm following along with some Spark tutorials, but to my frustration the instructor doesn't seem to follow Scala best practices. He uses vars and state changes, whereas I'd prefer to work with immutable data. I've therefore changed the code to better fit the functional paradigm, but since I'm still new to the language I'm not 100% sure my code is equivalent.

Essentially, the code is supposed to loop from 1 to 10. On each iteration it runs .flatMap() over the RDD and checks whether a LongAccumulator has reached a certain value. If it has, the loop breaks; otherwise .reduceByKey() is performed on the RDD and the next iteration starts.

A couple of things I'm not sure about:

  1. Whether my use of the LongAccumulator is correct (I try to increment it by passing it into the function): in my map function
  2. Whether I have to store the result of a function such as .flatMap() on the RDD in a new val, or whether simply calling rdd.flatMap() stores the result at a new memory address and points the pre-existing val rdd to that new address (similar to how immutable Vectors work): in my for loop (a short sketch of what I mean follows this list)
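
To make the two points concrete, here is a minimal, self-contained sketch of what I mean (the data and the app name are made up and are not from the tutorial):

val sc = new SparkContext("local[*]", "Sketch")
val counter = sc.longAccumulator("hits")
val numbers = sc.parallelize(Seq(1, 2, 3))

// Style A: capture the transformed RDD in a new val and run the action on that
val mapped = numbers.map { x => if (x == 2) counter.add(1); x }
println("count: " + mapped.count() + ", accumulator: " + counter.value) // count() is an action, so the map actually runs

// Style B: call the transformation on the existing val without storing the result
numbers.map { x => if (x == 2) counter.add(1); x }
println("count: " + numbers.count() + ", accumulator: " + counter.value)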

For those wondering, this is an implementation of a breadth-first search algorithm on a graph constructed from a .txt file. Where the instructor chose to use colours, I use the numbers 1-3 instead to simplify the code. I also chose Vectors over Arrays.

Initialising the variables:

//My code, outside main class


type NodeMetaData = (Vector[Int], Int, Int)
type Node = (Int, NodeMetaData)
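// (Judging from the solution further down, the numbers roughly mean:
//  1 = unvisited, 2 = marked for expansion ("grey"), 3 = processed ("black").)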

//Within main class


val sparkContext = new SparkContext("local[*]", "TestContext")

val hitCounter = sparkContext.longAccumulator("counter")



//Tutorial code, outside main class


// We make our accumulator a "global" Option so we can reference it in a
// mapper later.
var hitCounter:Option[LongAccumulator] = None

// Some custom data types 
// BFSData contains an array of hero ID connections, the distance, and
// color.
type BFSData = (Array[Int], Int, String)
// A BFSNode has a heroID and the BFSData associated with it.
type BFSNode = (Int, BFSData)


//Within main class


val sc = new SparkContext("local[*]", "DegreesOfSeparation") 

// Our accumulator, used to signal when we find the target 
// character in our BFS traversal.
hitCounter = Some(sc.longAccumulator("Hit Counter"))

My for loop:

val normalisedValues = graphText.map(x => parseLine(x, 5306)) //rdd
breakable {
  for (n <- 1 to 10) {
    println("Running bfs iteration#" + n)
    normalisedValues.flatMap(x => mapBfs(x, 14, hitCounter))
    println("Processing " + normalisedValues.count() + " values\n")

    if (hitCounter.isRegistered) {
      if (hitCounter.value > 0) {
        println("Found target from " + hitCounter.value + " different directions")
        break
      } else {
        normalisedValues.reduceByKey(reduceBfs)
      }
    }
  }
}

The tutorial's loop:

var iterationRdd = createStartingRdd(sc)
var iteration:Int = 0
    for (iteration <- 1 to 10) {
      println("Running BFS Iteration# " + iteration)

      // Create new vertices as needed to darken or reduce distances in the
      // reduce stage. If we encounter the node we're looking for as a GRAY
      // node, increment our accumulator to signal that we're done.
      val mapped = iterationRdd.flatMap(bfsMap)

      // Note that mapped.count() action here forces the RDD to be evaluated, and
      // that's the only reason our accumulator is actually updated.  
      println("Processing " + mapped.count() + " values.")

      if (hitCounter.isDefined) {
        val hitCount = hitCounter.get.value
        if (hitCount > 0) {
          println("Hit the target character! From " + hitCount + 
              " different direction(s).")
          return
        }
      }

      // Reducer combines data for each character ID, preserving the darkest
      // color and shortest path.      
      iterationRdd = mapped.reduceByKey(bfsReduce)
    }
  }

My map function:

def mapBfs(node:Node, targetId:Int, counter: LongAccumulator): Vector[Node] = {
    val relations = node._2
    val results:Vector[Node] = Vector.empty[Node]
    if(relations._2 == 2) {
      relations._1.foreach(x => {
        if (x == targetId) {
          if (counter.isRegistered) counter.add(1) else None
        }
        results :+ (x, (Vector(), counter.value.toInt, 2))

      })

    }

    results :+ (node._1, (relations._1, relations._2, if(relations._2 == 2) 3 else 2))
  }

The tutorial's map function:

/** Expands a BFSNode into this node and its children */
  def bfsMap(node:BFSNode): Array[BFSNode] = {

    // Extract data from the BFSNode
    val characterID:Int = node._1
    val data:BFSData = node._2

    val connections:Array[Int] = data._1
    val distance:Int = data._2
    var color:String = data._3

    // This is called from flatMap, so we return an array
    // of potentially many BFSNodes to add to our new RDD
    var results:ArrayBuffer[BFSNode] = ArrayBuffer()

    // Gray nodes are flagged for expansion, and create new
    // gray nodes for each connection
    if (color == "GRAY") {
      for (connection <- connections) {
        val newCharacterID = connection
        val newDistance = distance + 1
        val newColor = "GRAY"

        // Have we stumbled across the character we're looking for?
        // If so increment our accumulator so the driver script knows.
        if (targetCharacterID == connection) {
          if (hitCounter.isDefined) {
            hitCounter.get.add(1)
          }
        }

        // Create our new Gray node for this connection and add it to the results
        val newEntry:BFSNode = (newCharacterID, (Array(), newDistance, newColor))
        results += newEntry
      }

      // Color this node as black, indicating it has been processed already.
      color = "BLACK"
    }

    // Add the original node back in, so its connections can get merged with 
    // the gray nodes in the reducer.
    val thisEntry:BFSNode = (characterID, (connections, distance, color))
    results += thisEntry

    return results.toArray
  }

1 answer:

Answer 0 (score: 0)

Solution:

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext

type NodeMetaData = (Vector[Int], Int, Int)
type Node = (Int, NodeMetaData)



  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.ERROR)

    val startID:Int = 5306 //Spiderman
    val targetID:Int = 14 //Adam 3031

    val sparkContext = new SparkContext("local[*]", "TestContext")

    val hitCounter = sparkContext.longAccumulator("counter")

    val nameText = sparkContext.textFile("SparkScala/Marvel-names.txt")
    val nameLookup = nameText.flatMap(mapName)
    val graphText = sparkContext.textFile("SparkScala/Marvel-graph.txt")

    var normalisedValues = graphText.map(x => createNode(x, startID)).reduceByKey((x,y) => (x._1.union(y._1),x._2,x._3))

    for(n <- 1 to 10){
      val searchables: Vector[Int] = normalisedValues.filter(_._2._3 == 3).reduce((x,y) => (0,(x._2._1.union(y._2._1),0,0)))._2._1
      val broadcast = sparkContext.broadcast(searchables)
      val currentDistance = sparkContext.broadcast(n)
      normalisedValues = normalisedValues.map(x => if(broadcast.value.contains(x._1) && x._2._3 != 3) markNode(x, currentDistance.value) else x).map(conquerNode)
    }

    val collection = normalisedValues.collect()

    collection.sortWith((x,y) => x._2._2 < y._2._2 ).foreach(x =>{
      if(x._1 == targetID) {
        println(f"${nameLookup.lookup(targetID).head} is ${x._2._2} degrees away from ${nameLookup.lookup(startID).head}")
      }
    })
  }


  def mapName(line:String): Option[(Int, String)]={
    val values = line.split(" ", 2)
    val name = values(1).replace("\"", "")
    Some(values(0).toInt, name)
  }

  def createNode(line: String, startId:Int): Node ={
    val connections = line.split("\\s+")
    val id = connections(0).toInt
    val relations = connections.map(_.toInt).filter( _ != id ).toVector
    val color = if ( id == startId ) 3 else 1
    val distance = if ( id == startId )  0 else 9999
    (id, ( relations, distance, color ))

  }

  def markNode(node:Node, distance:Int): Node ={
    (node._1,(node._2._1, distance, 2))
  }


  def conquerNode(node:Node): Node = {
    if(node._2._3 == 2) {
      (node._1, (node._2._1, node._2._2, 3))
    } else { node }
  }

The solution was to create two separate functions: one that marks a node as grey (colour == 2), and another that then "conquers" that node by marking it black. To work out which nodes should be marked grey next, the union of the connection Vectors of all previously marked black nodes is broadcast at the start of each loop iteration.
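
As a tiny, made-up illustration of that mark/conquer step on a single node (the ID and connections here are invented):

val node: Node  = (42, (Vector(1, 2), 9999, 1)) // unvisited node with two connections
val marked      = markNode(node, 1)             // (42, (Vector(1, 2), 1, 2))  distance 1, marked grey
val conquered   = conquerNode(marked)           // (42, (Vector(1, 2), 1, 3))  now black/processed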

One shortcoming of this implementation is the use of var normalisedValues. This could be avoided by appending each new RDD to the tail of a Vector and simply reading the tail at the start of the next iteration, but I didn't want to do that, as it would create a large memory overhead if every RDD had to be kept around. It could be done for demonstration purposes, though (a rough sketch follows at the end of this answer). See


https://raw.githubusercontent.com/maciejmarczak/spark-exercises/master/src/main/scala/org/maciejmarczak/scala/DegreesOfSeparation.scala

for the original code.
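
As mentioned above, the var could be avoided by folding over a Vector of RDDs instead. The following is only a rough sketch of that idea, not code from the solution; it assumes the same sparkContext, graphText, startID and the createNode/markNode/conquerNode helpers defined above, and initialValues is just a name introduced here for the starting RDD:

// Sketch only: each iteration reads the most recent RDD (the last element of the
// Vector) and appends the newly transformed RDD, so no var is needed. The
// trade-off is that a reference to every intermediate RDD stays alive.
val initialValues = graphText.map(x => createNode(x, startID))
  .reduceByKey((x, y) => (x._1.union(y._1), x._2, x._3))

val iterations = (1 to 10).foldLeft(Vector(initialValues)) { (rdds, n) =>
  val previous = rdds.last
  // union of the connection Vectors of all nodes already coloured 3 (black)
  val searchables = previous
    .filter(_._2._3 == 3)
    .reduce((x, y) => (0, (x._2._1.union(y._2._1), 0, 0)))
    ._2._1
  val frontier = sparkContext.broadcast(searchables)
  val next = previous
    .map(x => if (frontier.value.contains(x._1) && x._2._3 != 3) markNode(x, n) else x)
    .map(conquerNode)
  rdds :+ next // append; read back via .last at the start of the next iteration
}

val finalValues = iterations.last // plays the role of normalisedValues after the loop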