Converting an imperative solution to a concurrent map solution

Time: 2014-04-08 19:20:47

Tags: scala

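The code below runs some code for every user in a list. It simply compares each user with every user in the list and concatenates their attributes: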

  case class UserObj(id: String, nCoordinate: String)

  val userList = List(UserObj("a1", "1234"), UserObj("a2", "1234"), UserObj("a3", "1234"))
  val map1 = new java.util.concurrent.ConcurrentHashMap[String, Double]

  userList.par.map(xUser => {
    userList.par.map(yUser => {
      if (!xUser.id.isEmpty() && !yUser.id.isEmpty()) {
        println("Total is " + xUser.id + yUser.id + "," + xUser.nCoordinate + yUser.nCoordinate)
        map1.put(xUser.id + "," + yUser.id, getJaccardDistance(xUser.nCoordinate, yUser.nCoordinate))
      }
    })
    println("")
  })                                              //> Total is a1a1,12341234
                                                  //| Total is a3a1,12341234
                                                  //| Total is a2a1,12341234
                                                  //| Total is a3a2,12341234
                                                  //| Total is a1a2,12341234
                                                  //| Total is a3a3,12341234
                                                  //| Total is a2a2,12341234
                                                  //| 
                                                  //| Total is a1a3,12341234
                                                  //| Total is a2a3,12341234
                                                  //| 
                                                  //| 
                                                  //| res0: scala.collection.parallel.immutable.ParSeq[Unit] = ParVector((), (), (
                                                  //| ))

  def getJaccardDistance(str1: String, str2: String) = {

    val zipped = str1.zip(str2)
    val numberOfEqualSequences = zipped.count(_ == ('1', '1')) * 2

    val p = zipped.count(_ == ('1', '1')).toFloat * 2
    val q = zipped.count(_ == ('1', '0')).toFloat * 2
    val r = zipped.count(_ == ('0', '1')).toFloat * 2
    val s = zipped.count(_ == ('0', '0')).toFloat * 2

    (q + r) / (p + q + r)

  }
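
For reference, `numberOfEqualSequences` and `s` above are never used, and the `* 2` factors cancel in the final ratio. The function also returns NaN when neither string contributes a '1' at any position, because the denominator is then zero. A minimal sketch of an equivalent computation follows. The name `jaccardDistance` and the choice to return 0.0 for a zero denominator are assumptions made for illustration, not something the question specifies.

```scala
// Hypothetical cleaned-up variant of getJaccardDistance.
// Assumption: a zero denominator is reported as distance 0.0 instead of NaN.
def jaccardDistance(a: String, b: String): Double = {
  val pairs = a.zip(b)                 // character pairs, position by position
  val p = pairs.count(_ == ('1', '1')) // '1' in both strings
  val q = pairs.count(_ == ('1', '0')) // '1' only in the first string
  val r = pairs.count(_ == ('0', '1')) // '1' only in the second string
  val denominator = p + q + r
  if (denominator == 0) 0.0 else (q + r).toDouble / denominator
}
```

For the sample data, `jaccardDistance("1234", "1234")` is 0.0, since the only counted pair is ('1', '1').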

This is the previous imperative solution:

  for (xUser <- userList) {
    for (yUser <- userList) {
      if (!xUser.id.isEmpty() && !yUser.id.isEmpty()) {
        println("Total is " + xUser.id + yUser.id + "," + xUser.nCoordinate + yUser.nCoordinate)
      }
    }
    println("")
  }

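But I want to make use of Scala's parallel collections, and I think using map is the recommended way to achieve this, since the imperative code above could result in multiple threads running the same code. Note: the statement being executed above, println("Total is "+xUser.id+yUser.id+","+xUser.nCoordinate+yUser.nCoordinate), is just a simpler version of the algorithm that is actually run.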

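The functional solution posted at the start of the question behaves as expected, but once the list contains more than 3000 elements it almost grinds to a halt. Why is this occurring? Is my implementation correct?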

1 answer:

Answer 0: (score: 1)

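Unless you provide the actual algorithm, we can only guess. I tried it with 3000 elements and it ran fine, although slower than the plain (sequential) map.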

Why is it slower? Because println is synchronized. Take a look at java.io.PrintStream:

public void println(String x) {
    synchronized (this) {
        print(x);
        newLine();
    }
}

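So clearly parallelizing println does not make much sense: every thread has to wait for the same lock. You can share your algorithm so we can see what is happening under the covers, or dig into your code to make sure nothing in it is synchronized (for example, if you println somewhere, consider using asynchronous logging instead).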
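
One way to keep the parallel work off the shared PrintStream lock is to build the output lines in the worker threads and print them once at the end. The sketch below is an illustration under stated assumptions: it uses scala.concurrent.Future instead of `.par` so that it runs without the parallel collections module, and `linesFor` and `allLines` are hypothetical helper names that do not appear in the question.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

case class UserObj(id: String, nCoordinate: String)

// Builds the output lines for one outer user without touching System.out.
def linesFor(x: UserObj, users: Seq[UserObj]): Seq[String] =
  users.collect {
    case y if x.id.nonEmpty && y.id.nonEmpty =>
      "Total is " + x.id + y.id + "," + x.nCoordinate + y.nCoordinate
  }

// Runs one Future per outer user; Future.traverse keeps the input order.
// Assumption: a one-minute timeout is enough for the workload at hand.
def allLines(users: Seq[UserObj]): Seq[String] = {
  val rows = Future.traverse(users)(x => Future(linesFor(x, users)))
  Await.result(rows, 1.minute).flatten
}
```

A caller would then issue a single `println(allLines(userList).mkString("\n"))`, so the lock is taken once rather than once per pair.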

The code I used for testing is:

case class UserObj(id: String, nCoordinate : String)

val userList = (1 to 3000).map(i => UserObj("a"+i.toString, "1234"))

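import scala.collection.mutable

// Collects the timing lines so they can be printed after both runs finish
val timings = new mutable.StringBuilder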
def time[R](block: => R): R = {
  val t0 = System.nanoTime()
  val result = block
  val t1 = System.nanoTime()
  timings.append("Elapsed time: " + (t1 - t0) + "ns\n")
  result
}


time {
  userList.map(xUser => {
    userList.map(yUser => {
      if (!xUser.id.isEmpty && !yUser.id.isEmpty) {
        println("Total is " + xUser.id + yUser.id + "," + xUser.nCoordinate + yUser.nCoordinate)
      }
    })
  })
}

time {
  userList.par.map(xUser => {
    userList.par.map(yUser => {
      if (!xUser.id.isEmpty && !yUser.id.isEmpty) {
        println("Total is " + xUser.id + yUser.id + "," + xUser.nCoordinate + yUser.nCoordinate)
      }
    })
  })
}

println(timings.toString())

It returned the following timings (sequential run first, then parallel; roughly 29.1 s and 37.0 s):

Elapsed time: 29066452631ns
Elapsed time: 37031631461ns
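
A side note on the question's own code: it calls map only for its side effects on map1, which is why the worksheet reports a ParSeq[Unit]. A minimal sketch of the alternative, in which the traversal itself returns the key/distance pairs, might look like the following. It is kept sequential so that it runs without parallel collections. `pairwiseDistances` and `distance` are hypothetical names, and `distance` is only a stand-in for the question's getJaccardDistance with an assumed 0.0 result for a zero denominator.

```scala
case class UserObj(id: String, nCoordinate: String)

// Stand-in for getJaccardDistance (assumption: same ratio, 0.0 when undefined).
def distance(a: String, b: String): Double = {
  val pairs = a.zip(b)
  val p = pairs.count(_ == ('1', '1'))
  val q = pairs.count(_ == ('1', '0'))
  val r = pairs.count(_ == ('0', '1'))
  if (p + q + r == 0) 0.0 else (q + r).toDouble / (p + q + r)
}

// Returns an immutable Map instead of mutating a shared ConcurrentHashMap.
def pairwiseDistances(users: Seq[UserObj]): Map[String, Double] =
  (for {
    x <- users
    y <- users
    if x.id.nonEmpty && y.id.nonEmpty
  } yield (x.id + "," + y.id) -> distance(x.nCoordinate, y.nCoordinate)).toMap
```

Because nothing is shared between iterations, the outer traversal is a natural candidate for parallel execution once the per-pair work is expensive enough to outweigh the scheduling overhead.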