斯卡拉计数词的共现表现真的很低

时间:2018-06-05 09:16:12

标签: python scala performance

当我尝试实现一个函数来计算scala中的单词共现时,我发现我的函数性能非常低。

同时出现的词是:
也就是说我们有一个List [List [Int]](实际上是单词列表的列表),
我们将为每个List [Int],
生成一个组合 然后我们将所有组合合并到一个地图中,并将每个重复键的值相加。

组合:
[0,1,2] - > [((0,1),1),((0,2),1),((1,2),1)]

合并组合:
[((0,1),1),((0,2),1),((1,2),1)] + [((0,1),1),((0,2), 1),((1,2),1)] =
HashMap {(0,1):2,(0,2):2,(1,2):2}

这是scala版本:

val arr = Array.range(0, 1000)
val counter = scala.collection.mutable.HashMap[(Int, Int), Int](  )
arr.combinations(2).toArray.map{
    row=>
        val key = (row(0), row(1))
        if (!counter.contains(key)) {
            counter(key) = 1
        }
        else {
            counter(key) += 1
        }
}
assert(counter.size == 499500)

Scala版本2:

val counter = arr.combinations(2).map(x => ((x(0),x(1)), 1)).toArray
.groupBy(_._1).mapValues(_.map(_._2).sum)

这是python版本:

import itertools    
arr = range(0, 1000)
combs = list(itertools.combinations(arr, 2))
counter = dict()
for key in combs:
    try:
        counter[key] += 1
    except KeyError:
        counter[key] = 1
assert len(counter) == 499500

两个scala版本的成本为9秒,而python版本的成本为1秒 我认为我肯定在做错代码,但我想不出其他方法来改进它(我对scala很新)。

另外,我使用mutable.HashMap的原因是我想减少内存使用量。

任何帮助都将不胜感激,谢谢。

2 个答案:

答案 0 :(得分:0)

您需要将arr转换为并行集合。理想情况下,对于RDD。因此,创建一个spark上下文,从下面的数组中获取RDD,然后对其进行操作。

val arr: RDD[Int] = sparkContext.parallelize(Array.range(0, 1000))

你真的应该查找some tutorials

答案 1 :(得分:0)

问题在于集合中的combine方法。它创建了一个效率不高的迭代器。我没有使用联合收割机创建了另一个x10更快的样本:

  def time[R](block: => R): R = {
    val t0 = System.currentTimeMillis()
    val result = block    // call-by-name
    val t1 = System.currentTimeMillis()
    println("Elapsed time: " + (t1 - t0) + "ms")
    result
  }

  val arr = Array.range(0, 1000).toList

  def combinations2[A](input: List[A]): Iterator[(A, A)] =
    input.tails.flatMap(_ match {
      case h :: t => t.iterator.map((h, _))
      case Nil => Iterator.empty
    })

  val counter = scala.collection.mutable.HashMap[(Int, Int), Int](  )
  time {
    combinations2(arr).foreach {
      row =>
        val key = row
        if (!counter.contains(key)) {
          counter(key) = 1
        }
        else {
          counter(key) += 1
        }
    }
    assert(counter.size == 499500)
  }

检查出来