使用聚合的Scala并行频率计算不起作用

时间:2015-06-04 06:32:37

标签: multithreading scala concurrency parallel-processing aggregate

我正在通过“Scala for the Impatient”一书中的练习来学习Scala。请参阅以下问题以及我的答案和代码。我想知道我的答案是否正确。此代码也不起作用(所有频率均为1)。错误在哪里?

  

问题10:Harry Hacker将文件读入字符串并想要使用   并行收集以同时更新字母频率   部分字符串。他使用以下代码:

val frequencies = new scala.collection.mutable.HashMap[Char, Int]
for (c <- str.par) frequencies(c) = frequencies.getOrElse(c, 0) + 1
     

为什么这是一个糟糕的主意?他怎么能真正并行化   计算

我的回答: 这不是一个好主意,因为如果两个线程同时更新相同的频率,则结果是未定义的。

我的代码:

def parFrequency(str: String) = {
  str.par.aggregate(Map[Char, Int]())((m, c) => { m + (c -> (m.getOrElse(c, 0) + 1)) }, _ ++ _)
}

单元测试:

"Method parFrequency" should "return the frequency of each character in a string" in {
  val freq = parFrequency("harry hacker")

  freq should have size 8

  freq('h') should be(2) // fails
  freq('a') should be(2)
  freq('r') should be(3)
  freq('y') should be(1)
  freq(' ') should be(1)
  freq('c') should be(1)
  freq('k') should be(1)
  freq('e') should be(1)
}

修改: 阅读this线程后,我更新了代码。现在测试可以单独运行,但如果作为套件运行则会失败。

def parFrequency(str: String) = {
  val freq = ImmutableHashMap[Char, Int]()
  str.par.aggregate(freq)((_, c) => ImmutableHashMap(c -> 1), (m1, m2) => m1.merged(m2)({
      case ((k, v1), (_, v2)) => (k, v1 + v2)
  }))
}

编辑2: 请参阅下面的解决方案。

3 个答案:

答案 0 :(得分:0)

++未合并相同键的值。因此,当您合并地图时,您会获得(对于共享密钥)其中一个值(在这种情况下始终为1),而不是值的总和。

这有效:

def parFrequency(str: String) = {
  str.par.aggregate(Map[Char, Int]())((m, c) => { m + (c -> (m.getOrElse(c, 0) + 1)) },
  (a,b) => b.foldLeft(a){case (acc, (k,v))=> acc updated (k, acc.getOrElse(k,0) + v) })
} 

val freq = parFrequency("harry hacker")
//> Map(e -> 1, y -> 1, a -> 2,   -> 1, c -> 1, h -> 2, r -> 3, k -> 1)

foldLeft迭代其中一张地图,使用找到的键/值更新另一张地图。

答案 1 :(得分:0)

您在第一种情况下遇到麻烦,因为您自己检测到的是++运算符,它只是连接,丢弃了相同密钥的第二次出现。

现在在第二种情况下,你有(_, c) => ImmutableHashMap(c -> 1),它只会删除在seqop阶段找到地图的所有字符。

我的建议是使用特殊的合并操作扩展Map类型,与merged中的HashMap类似,并保留seqop阶段的第一个示例的收集:

implicit class MapUnionOps[K, V](m1: Map[K, V]) {
  def unionWith[V1 >: V](m2: Map[K, V1])(f: (V1, V1) => V1): Map[K, V1] = {
    val kv1 = m1.filterKeys(!m2.contains(_))
    val kv2 = m2.filterKeys(!m1.contains(_))
    val common = (m1.keySet & m2.keySet).toSeq map (k => (k, f(m1(k), m2(k))))
    (common ++ kv1 ++ kv2).toMap
  }
}

def parFrequency(str: String) = {
  str.par.aggregate(Map[Char, Int]())((m, c) => {m + (c -> (m.getOrElse(c, 0) + 1))}, (m1, m2) =>  (m1 unionWith m2)(_ + _))
}

或者您可以使用Paul的答案中的fold解决方案,但为了更好地表现每个合并,请选择较小的地图进行遍历:

implicit class MapUnionOps[K, V](m1: Map[K, V]) {
  def unionWith(m2: Map[K, V])(f: (V, V) => V): Map[K, V] =
    if (m2.size > m1.size) m2.unionWith(m1)(f)
    else m2.foldLeft(m1) {
      case (acc, (k, v)) => acc + (k -> acc.get(k).fold(v)(f(v, _)))
    }
}

答案 2 :(得分:0)

这似乎有效。我比这里提出的其他解决方案更喜欢它,因为:

  1. 代码少于implicit class,代码少于getOrElse foldLeft使用的代码。
  2. 它使用API​​中的merged函数来执行我想要的操作。
  3. 这是我自己的解决方案:)

    def parFrequency(str: String) = {
      val freq = ImmutableHashMap[Char, Int]()
      str.par.aggregate(freq)((_, c) => ImmutableHashMap(c -> 1), _.merged(_) {
        case ((k, v1), (_, v2)) => (k, v1 + v2)
      })
    }
    
  4. 感谢您抽出宝贵时间帮助我。