防止std :: atomic溢出

时间:2013-05-28 05:47:19

标签: c++ c++11 atomic

我有一个原子计数器(std::atomic<uint32_t> count),可以按顺序将值递增到多个线程。

uint32_t my_val = ++count;

在我得到my_val之前,我想确保增量不会溢出(即:回到0)

if (count == std::numeric_limits<uint32_t>::max())
    throw std::runtime_error("count overflow");

我认为这是一个天真的检查,因为如果在增加计数器之前由两个线程执行检查,则增加的第二个线程将返回0

if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
    throw std::runtime_error("count overflow");
uint32_t my_val = ++count;       // before either gets here - possible overflow

因此我想我需要使用CAS操作来确保当我增加计数器时,我确实防止了可能的溢出。

所以我的问题是:

  • 我的实施是否正确?
  • 它是否尽可能高效(特别是我需要两次检查max)?

我的代码(带有工作范例)如下:

#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>

std::atomic<uint16_t> count;

uint16_t get_val() // called by multiple threads
{
    uint16_t my_val;
    do
    {
        my_val = count;

        // make sure I get the next value

        if (count.compare_exchange_strong(my_val, my_val + 1))
        {
            // if I got the next value, make sure we don't overflow

            if (my_val == std::numeric_limits<uint16_t>::max())
            {
                count = std::numeric_limits<uint16_t>::max() - 1;
                throw std::runtime_error("count overflow");
            }
            break;
        }

        // if I didn't then check if there are still numbers available

        if (my_val == std::numeric_limits<uint16_t>::max())
        {
            count = std::numeric_limits<uint16_t>::max() - 1;
            throw std::runtime_error("count overflow");
        }

        // there are still numbers available, so try again
    }
    while (1);
    return my_val + 1;
}

void run()
try
{
    while (1)
    {
        if (get_val() == 0)
            exit(1);
    }

}
catch(const std::runtime_error& e)
{
    // overflow
}

int main()
{
    while (1)
    {
        count = 1;
        std::thread a(run);
        std::thread b(run);
        std::thread c(run);
        std::thread d(run);
        a.join();
        b.join();
        c.join();
        d.join();
        std::cout << ".";
    }
    return 0;
}

3 个答案:

答案 0 :(得分:6)

是的,您需要使用CAS操作。

std::atomic<uint16_t> g_count;

uint16_t get_next() {
   uint16_t new_val = 0;
   do {
      uint16_t cur_val = g_count;                                            // 1
      if (cur_val == std::numeric_limits<uint16_t>::max()) {                 // 2
          throw std::runtime_error("count overflow");
      }
      new_val = cur_val + 1;                                                 // 3
   } while(!std::atomic_compare_exchange_weak(&g_count, &cur_val, new_val)); // 4

   return new_val;
}

这个想法如下:一旦g_count == std::numeric_limits<uint16_t>::max()get_next()函数将始终抛出异常。

步骤:

  1. 获取计数器的当前值
  2. 如果是最大值,则抛出异常(不再提供数字)
  3. 获取新值作为当前值的增量
  4. 尝试以原子方式设置新值。如果我们未能设置它(它已经由另一个线程完成),请再试一次。

答案 1 :(得分:2)

如果效率是一个大问题,那么我建议不要对支票这么严格。我猜测在正常情况下使用溢出不会是一个问题,但你真的需要完整的65K范围(你的例子使用uint16)?

如果你假设你运行的线程数有一些最大值会更容易。这是一个合理的限制,因为没有程序具有无限数量的并发性。因此,如果您拥有N个主题,则只需将溢出限制降低到65K - N即可。要比较你是否溢出,你不需要CAS:

uint16_t current = count.load(std::memory_order_relaxed);
if( current >= (std::numeric_limits<uint16_t>::max() - num_threads - 1) )
    throw std::runtime_error("count overflow");
count.fetch_add(1,std::memory_order_relaxed);

这会产生软溢出情况。如果两个线程同时到达它们,它们都可能通过,但这没关系,因为count变量本身永远不会溢出。此时任何未来到达都将在逻辑上溢出(直到计数再次减少)。

答案 2 :(得分:1)

在我看来,仍然存在竞争条件,其中count将暂时设置为0,以便另一个线程将看到0值。

假设count位于std::numeric_limits<uint16_t>::max(),并且两个线程尝试获取递增的值。在线程1执行count.compare_exchange_strong(my_val, my_val + 1)的那一刻,count被设置为0,这就是线程2在线程1有机会恢复{{1}之前调用并完成get_val()时将看到的内容} count