我有一个原子计数器(std::atomic<uint32_t> count
),可以按顺序将值递增到多个线程。
uint32_t my_val = ++count;
在我得到my_val
之前,我想确保增量不会溢出(即:回到0)
if (count == std::numeric_limits<uint32_t>::max())
throw std::runtime_error("count overflow");
我认为这是一个天真的检查,因为如果在增加计数器之前由两个线程执行检查,则增加的第二个线程将返回0
if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
throw std::runtime_error("count overflow");
uint32_t my_val = ++count; // before either gets here - possible overflow
因此我想我需要使用CAS
操作来确保当我增加计数器时,我确实防止了可能的溢出。
所以我的问题是:
max
)?我的代码(带有工作范例)如下:
#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>
std::atomic<uint16_t> count;
uint16_t get_val() // called by multiple threads
{
uint16_t my_val;
do
{
my_val = count;
// make sure I get the next value
if (count.compare_exchange_strong(my_val, my_val + 1))
{
// if I got the next value, make sure we don't overflow
if (my_val == std::numeric_limits<uint16_t>::max())
{
count = std::numeric_limits<uint16_t>::max() - 1;
throw std::runtime_error("count overflow");
}
break;
}
// if I didn't then check if there are still numbers available
if (my_val == std::numeric_limits<uint16_t>::max())
{
count = std::numeric_limits<uint16_t>::max() - 1;
throw std::runtime_error("count overflow");
}
// there are still numbers available, so try again
}
while (1);
return my_val + 1;
}
void run()
try
{
while (1)
{
if (get_val() == 0)
exit(1);
}
}
catch(const std::runtime_error& e)
{
// overflow
}
int main()
{
while (1)
{
count = 1;
std::thread a(run);
std::thread b(run);
std::thread c(run);
std::thread d(run);
a.join();
b.join();
c.join();
d.join();
std::cout << ".";
}
return 0;
}
答案 0 :(得分:6)
是的,您需要使用CAS
操作。
std::atomic<uint16_t> g_count;
uint16_t get_next() {
uint16_t new_val = 0;
do {
uint16_t cur_val = g_count; // 1
if (cur_val == std::numeric_limits<uint16_t>::max()) { // 2
throw std::runtime_error("count overflow");
}
new_val = cur_val + 1; // 3
} while(!std::atomic_compare_exchange_weak(&g_count, &cur_val, new_val)); // 4
return new_val;
}
这个想法如下:一旦g_count == std::numeric_limits<uint16_t>::max()
,get_next()
函数将始终抛出异常。
步骤:
答案 1 :(得分:2)
如果效率是一个大问题,那么我建议不要对支票这么严格。我猜测在正常情况下使用溢出不会是一个问题,但你真的需要完整的65K范围(你的例子使用uint16)?
如果你假设你运行的线程数有一些最大值会更容易。这是一个合理的限制,因为没有程序具有无限数量的并发性。因此,如果您拥有N
个主题,则只需将溢出限制降低到65K - N
即可。要比较你是否溢出,你不需要CAS:
uint16_t current = count.load(std::memory_order_relaxed);
if( current >= (std::numeric_limits<uint16_t>::max() - num_threads - 1) )
throw std::runtime_error("count overflow");
count.fetch_add(1,std::memory_order_relaxed);
这会产生软溢出情况。如果两个线程同时到达它们,它们都可能通过,但这没关系,因为count变量本身永远不会溢出。此时任何未来到达都将在逻辑上溢出(直到计数再次减少)。
答案 2 :(得分:1)
在我看来,仍然存在竞争条件,其中count
将暂时设置为0,以便另一个线程将看到0值。
假设count
位于std::numeric_limits<uint16_t>::max()
,并且两个线程尝试获取递增的值。在线程1执行count.compare_exchange_strong(my_val, my_val + 1)
的那一刻,count被设置为0,这就是线程2在线程1有机会恢复{{1}之前调用并完成get_val()
时将看到的内容} count
。