线程安全地转换可变映射中的值

时间:2013-08-09 14:29:59

标签: scala map concurrency thread-safety

假设我想在Scala中使用可变映射来跟踪我看到一些字符串的次数。在单线程环境中,这很容易:

import scala.collection.mutable.{ Map => MMap }

class Counter {
  val counts = MMap.empty[String, Int].withDefaultValue(0)

  def add(s: String): Unit = counts(s) += 1
}

不幸的是,这不是线程安全的,因为getupdate不会原子地发生。

Concurrent mapsa few atomic operations添加到可变映射API中,但不是我需要的那个,它看起来像这样:

def replace(k: A, f: B => B): Option[B]

我知道我可以使用ScalaSTMTMap

import scala.concurrent.stm._

class Counter {
  val counts =  TMap.empty[String, Int]

  def add(s: String): Unit = atomic { implicit txn =>
    counts(s) = counts.get(s).getOrElse(0) + 1
  }
}

但是(现在)这仍然是一个额外的依赖。其他选项包括actor(另一个依赖项),同步(可能效率较低)或Java atomic referencesless idiomatic)。

一般来说,我会避免Scala中的可变地图,但我偶尔会需要这种东西,而且最近我使用了STM方法(而不是只是交叉我的手指,希望我不会被咬通过天真​​的解决方案)。

我知道这里有许多权衡(额外依赖性与性能与清晰度等),但在Scala 2.10中是否存在类似“正确”解决此问题的答案?

4 个答案:

答案 0 :(得分:10)

这个怎么样?假设你现在不需要一般的replace方法,只需要一个计数器。

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

object CountedMap {
  private val counts = new ConcurrentHashMap[String, AtomicInteger]

  def add(key: String): Int = {
    val zero = new AtomicInteger(0)
    val value = Option(counts.putIfAbsent(key, zero)).getOrElse(zero)
    value.incrementAndGet
  }
}

你获得比在整个地图上同步更好的性能,并且你也获得了原子增量。

答案 1 :(得分:3)

最简单的解决方案绝对是同步。如果没有太多的争用,性能可能不会那么糟糕。

否则,您可以尝试汇总自己的STM replace实现。这样的事情可能会这样:

object ConcurrentMapOps {
  private val rng = new util.Random
  private val MaxReplaceRetryCount = 10
  private val MinReplaceBackoffTime: Long = 1
  private val MaxReplaceBackoffTime: Long = 20
}
implicit class ConcurrentMapOps[A, B]( val m: collection.concurrent.Map[A,B] ) {
  import ConcurrentMapOps._
  private def replaceBackoff() {
    Thread.sleep( (MinReplaceBackoffTime + rng.nextFloat * (MaxReplaceBackoffTime - MinReplaceBackoffTime) ).toLong ) // A bit crude, I know
  }

  def replace(k: A, f: B => B): Option[B] = {
    m.get( k ) match {
      case None => return None
      case Some( old ) =>
        var retryCount = 0
        while ( retryCount <= MaxReplaceRetryCount ) {
          val done = m.replace( k, old, f( old ) )
          if ( done ) {
            return Some( old )
          }
          else {         
            retryCount += 1
            replaceBackoff()
          }
        }
        sys.error("Could not concurrently modify map")
    }
  }
}

请注意,碰撞问题已本地化为给定密钥。如果两个线程访问相同的映射但处理不同的键,则不会发生冲突,并且替换操作将始终第一次成功。如果检测到碰撞,我们会等待一段时间(随机的时间量,以便最大限度地减少线程对同一个键永远战斗的可能性),然后再试一次。

我不能保证这是生产就绪的(我现在就把它扔掉了),但这可能就行了。

UPDATE :当然(正如IonuţG。Stan所指出的那样),如果你想要的只是增加/减少一个值,java的ConcurrentHashMap已经提供了无锁操作方式。 如果您需要一个将转换函数作为参数的更通用的replace方法,我的上述解决方案将适用。

答案 2 :(得分:2)

如果你的地图只是坐在那里作为一个val,你会遇到麻烦。如果它符合您的使用案例,我会推荐类似

的内容
class Counter {
  private[this] myCounts = MMap.empty[String, Int].withDefaultValue(0)
  def counts(s: String) = myCounts.synchronized { myCounts(s) }
  def add(s: String) = myCounts.synchronized { myCounts(s) += 1 }
  def getCounts = myCounts.synchronized { Map[String,Int]() ++ myCounts }
}

用于低争用用法。对于高争用,您应该使用旨在支持此类使用的并发映射(例如java.util.concurrent.ConcurrentHashMap)并将值包装在AtomicWhatever中。

答案 3 :(得分:2)

如果您可以使用基于未来的界面:

trait SingleThreadedExecutionContext {
  val ec = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())
}

class Counter extends SingleThreadedExecutionContext {
  private val counts = MMap.empty[String, Int].withDefaultValue(0)

  def get(s: String): Future[Int] = future(counts(s))(ec)

  def add(s: String): Future[Unit] = future(counts(s) += 1)(ec)
}

测试将如下所示:

class MutableMapSpec extends Specification {

  "thread safe" in {

    import ExecutionContext.Implicits.global

    val c = new Counter
    val testData = Seq.fill(16)("1")
    await(Future.traverse(testData)(c.add))
    await(c.get("1")) mustEqual 16
  }
}