多生产者、多消费者无锁队列

时间:2016-11-03 10:55:39

标签: c++ multithreading data-structures concurrency lock-free

我正在阅读 Anthony Williams 的《C++ Concurrency in Action》,并尝试把无锁数据结构那一章中的最后一个示例拼凑完整:一个支持多生产者、多消费者的无锁队列。我拼凑出的代码在测试时崩溃了。下面贴出 MPMC 队列的代码以及我使用的测试程序。

lock_free_queue_mpmc.hpp

#pragma once

#include <atomic>
#include <cstdint>
#include <memory>

template <class T>
class lock_free_queue_mpmc
{
public:
  lock_free_queue_mpmc()
  {
    counted_node_ptr counted_node;
    counted_node.ptr = new node;
    counted_node.external_count = 1;

    head_.store(counted_node);
    tail_.store(head_);
  }

  lock_free_queue_mpmc(const lock_free_queue_mpmc& other) = delete;
  lock_free_queue_mpmc& operator=(const lock_free_queue_mpmc& other) = delete;

  ~lock_free_queue_mpmc()
  {
    counted_node_ptr old_head = head_.load();
    while (node* const old_node = old_head.ptr)
    {
      head_.store(old_node->next);
      delete old_node;
      old_head = head_.load();
    }
  }

  void push(const T& new_value)
  {
    std::unique_ptr<T> new_data(new T(new_value));
    counted_node_ptr new_next;
    new_next.ptr = new node;
    new_next.external_count = 1;
    counted_node_ptr old_tail = tail_.load();

    while (true)
    {
      increase_external_count(tail_, old_tail);
      T* old_data = nullptr;
      if (old_tail.ptr->data.compare_exchange_strong(old_data, new_data.get()))
      {
        counted_node_ptr old_next = {0};
        if (!old_tail.ptr->next.compare_exchange_strong(old_next, new_next))
        {
          delete new_next.ptr;
          new_next = old_next;
        }
        set_new_tail(old_tail, new_next);
        new_data.release();
        break;
      }
      else
      {
        counted_node_ptr old_next = {0};
        if(old_tail.ptr->next.compare_exchange_strong(old_next, new_next))
        {
          old_next = new_next;
          new_next.ptr = new node;
        }
        set_new_tail(old_tail, old_next);
      }
    }
  }

  std::unique_ptr<T> pop()
  {
    counted_node_ptr old_head = head_.load(std::memory_order_relaxed);
    while (true)
    {
      increase_external_count(head_, old_head);
      node* const ptr = old_head.ptr;
      if(ptr == tail_.load().ptr)
      {
        return std::unique_ptr<T>();
      }
      counted_node_ptr next = ptr->next.load();
      if (head_.compare_exchange_strong(old_head,next))
      {
        T* const res = ptr->data.exchange(nullptr);
        free_external_counter(old_head);
        return std::unique_ptr<T>(res);
      }
      ptr->release_ref();
    }
  }

private:
  struct node;

  struct counted_node_ptr
  {
    int external_count;
    node* ptr;
  };

  struct node_counter
  {
    unsigned internal_count    : 30;
    unsigned external_counters : 2;
  };

  struct node
  {
    std::atomic<T*> data;
    std::atomic<node_counter> count;
    std::atomic<counted_node_ptr> next;

    node()
    {
      node_counter new_count;
      new_count.internal_count    = 0;
      new_count.external_counters = 2;
      count.store(new_count);

      counted_node_ptr new_next;
      new_next.ptr            = nullptr;
      new_next.external_count = 0;
      next.store(new_next);
    }

    void release_ref()
    {
      node_counter old_counter = count.load(std::memory_order_relaxed);
      node_counter new_counter;

      do
      {
        new_counter=old_counter;
        --new_counter.internal_count;
      }
      while(!count.compare_exchange_strong(old_counter, new_counter,
                                           std::memory_order_acquire,
                                           std::memory_order_relaxed));

      if(!new_counter.internal_count && !new_counter.external_counters)
      {
        delete this;
      }
    }
  };

private:
  void set_new_tail(counted_node_ptr& old_tail,
                    const counted_node_ptr& new_tail)
  {
    node* const current_tail_ptr = old_tail.ptr;

    while (!tail_.compare_exchange_weak(old_tail, new_tail) &&
           old_tail.ptr == current_tail_ptr);

    if(old_tail.ptr == current_tail_ptr)
    {
      free_external_counter(old_tail);
    }
    else
    {
      current_tail_ptr->release_ref();
    }
  }

  static void increase_external_count(std::atomic<counted_node_ptr>& counter,
                                      counted_node_ptr& old_counter)
  {
    counted_node_ptr new_counter;

    do
    {
      new_counter = old_counter;
      ++new_counter.external_count;
    }
    while(!counter.compare_exchange_strong(old_counter, new_counter,
                                           std::memory_order_acquire,
                                           std::memory_order_relaxed));

    old_counter.external_count = new_counter.external_count;
  }

  static void free_external_counter(counted_node_ptr& old_node_ptr)
  {
    node* const ptr = old_node_ptr.ptr;
    const int count_increase = old_node_ptr.external_count - 2;
    node_counter old_counter= ptr->count.load(std::memory_order_relaxed);
    node_counter new_counter;

    do
    {
      new_counter = old_counter;
      --new_counter.external_counters;
      new_counter.internal_count += count_increase;
    }
    while(!ptr->count.compare_exchange_strong(old_counter, new_counter,
                                              std::memory_order_acquire,
                                              std::memory_order_relaxed));

    if(!new_counter.internal_count && !new_counter.external_counters)
    {
      delete ptr;
    }
  }

private:

  std::atomic<counted_node_ptr> head_;
  std::atomic<counted_node_ptr> tail_;

};

main.cpp

#include <iostream>
#include <vector>
#include <algorithm>
#include <thread>
#include <atomic>

#include "lock_free_queue_mpmc.hpp"

using namespace std;

// Benchmark configuration: thread counts on each side and the total number
// of values the producers are meant to push through the queue.
constexpr size_t PRODUCER_THREADS_COUNT{100};
constexpr size_t CONSUMER_THREADS_COUNT{100};
constexpr size_t ELEMENTS_COUNT{100'000'000};

// Progress counters shared by the worker threads, plus the running total of
// every value the consumers have taken out of the queue.
std::atomic<size_t> valuesProduced{0};
std::atomic<size_t> valuesConsumed{0};
std::atomic<uint64_t> sum{0};

// The single queue shared by all producer and consumer threads.
lock_free_queue_mpmc<int> g_queue;

// -----------------------------------------------------

// Producer thread body. Each iteration claims a ticket with one atomic
// fetch_add before pushing, so exactly ELEMENTS_COUNT values are pushed in
// total across all producers. A separate "load, compare, then increment"
// would let several threads pass the check at once and overshoot the limit.
void producerFunc()
{
  while (valuesProduced.fetch_add(1) < ELEMENTS_COUNT)
  {
    g_queue.push(1);
  }
}

// -----------------------------------------------------

// Consumer thread body: keeps polling the shared queue until the consumers
// together have taken ELEMENTS_COUNT values out of it. Each value obtained is
// added to the global sum; empty polls are simply retried.
void consumerFunc()
{
  for (;;)
  {
    if (!(valuesConsumed < ELEMENTS_COUNT))
    {
      break;
    }

    auto popped = g_queue.pop();
    if (popped != nullptr)
    {
      sum += *popped;
      ++valuesConsumed;
    }
  }
}

// -----------------------------------------------------

int main()
{
  auto timeBegin = chrono::high_resolution_clock::now();

// -----------------------------------------------------

  vector<thread> producerThreads;
  producerThreads.reserve(PRODUCER_THREADS_COUNT);

  for (size_t i = 0; i < PRODUCER_THREADS_COUNT; ++i)
  {
    producerThreads.emplace_back(producerFunc);
  }

// -----------------------------------------------------

  vector<thread> consumerThreads;
  consumerThreads.reserve(CONSUMER_THREADS_COUNT);

  for (size_t i = 0; i < CONSUMER_THREADS_COUNT; ++i)
  {
    consumerThreads.emplace_back(consumerFunc);
  }

// -----------------------------------------------------

  for_each(producerThreads.begin(), producerThreads.end(), mem_fn(&thread::join));
  for_each(consumerThreads.begin(), consumerThreads.end(), mem_fn(&thread::join));

// -----------------------------------------------------

  auto timeEnd = chrono::high_resolution_clock::now();
  cout << "Sum: " << sum << endl;
  cout << "Time: " << chrono::duration_cast<chrono::milliseconds>(
            timeEnd - timeBegin).count() << endl;

  return 0;
}

GDB回溯

#0  0x00007ffff70fcc37 in __GI_raise (sig=sig@entry=6) at ../nptl/sysdeps/unix/sysv/linux/raise.c:56
#1  0x00007ffff7100028 in __GI_abort () at abort.c:89
#2  0x00007ffff71392a4 in __libc_message (do_abort=do_abort@entry=1, fmt=fmt@entry=0x7ffff72476b0 "*** Error in `%s': %s: 0x%s ***\n") at ../sysdeps/posix/libc_fatal.c:175
#3  0x00007ffff714555e in malloc_printerr (ptr=<optimized out>, str=0x7ffff72477e0 "double free or corruption (out)", action=1) at malloc.c:4996
#4  _int_free (av=<optimized out>, p=<optimized out>, have_lock=0) at malloc.c:3840
#5  0x00000000004029fb in std::default_delete<int>::operator() (this=0x7fffee5bde20, __ptr=0x7ffff00009c0) at /usr/include/c++/6/bits/unique_ptr.h:76
#6  0x0000000000401e0f in std::unique_ptr<int, std::default_delete<int> >::~unique_ptr (this=0x7fffee5bde20, __in_chrg=<optimized out>) at /usr/include/c++/6/bits/unique_ptr.h:236
#7  0x000000000040127f in consumerFunc () at /home/bobeff/data/Dropbox/sources/books_sources/cpp_concurrency/21_lock_free_queue_mpmc/main.cpp:38
#8  0x0000000000403fc7 in std::_Bind_simple<void (*())()>::_M_invoke<>(std::_Index_tuple<>) (this=0x619de8) at /usr/include/c++/6/functional:1400
#9  0x0000000000403f67 in std::_Bind_simple<void (*())()>::operator()() (this=0x619de8) at /usr/include/c++/6/functional:1389
#10 0x0000000000403ea2 in std::thread::_State_impl<std::_Bind_simple<void (*())()> >::_M_run() (this=0x619de0) at /usr/include/c++/6/thread:196
#11 0x00007ffff773287f in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#12 0x00007ffff7bc4184 in start_thread (arg=0x7fffee5be700) at pthread_create.c:312
#13 0x00007ffff71c037d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:111

有人能找到代码的错误吗?

0 个答案:

没有答案