同步3个线程以打印顺序输出

时间:2019-06-24 07:38:32

标签: multithreading c++11

在一次采访中有人问我这个问题。我很笨。 因此,我决定学习一些多线程技术,并希望找到该问题的答案。

我需要使用3个线程来打印输出:01020304050607 .....

  1. 线程1:打印0
  2. 线程2:打印奇数
  3. 线程3:打印偶数
#include <iostream>
#include <thread>
#include <mutex>
#include <condition_variable>

std::mutex m;
std::condition_variable cv1, cv2, cv3;

int count = 0;

void printzero(int end)
{
    while (count <= end)
    {
        std::unique_lock<std::mutex> lock(m);
        cv1.wait(lock);
        std::cout << 0 << " ";
        ++count;
        if (count % 2 == 1)
        {
            lock.unlock();
            cv2.notify_one();
        }
        else
        {
            lock.unlock();
            cv3.notify_one();
        }
    }
}

void printodd(int end)
{
    while (count <= end)
    {
        std::unique_lock<std::mutex> lock(m);
        cv2.wait(lock);
        if (count % 2 == 1)
        {
            std::cout << count << " ";
            ++count;
            lock.unlock();
            cv1.notify_one();
        }
    }
}

void printeven(int end)
{
    while (count <= end)
    {
        std::unique_lock<std::mutex> lock(m);
        cv3.wait(lock);
        if (count % 2 == 0)
        {
            std::cout << count << " ";
            ++count;
            lock.unlock();
            cv1.notify_one();
        }
    }
}

int main()
{
    int end = 10;

    std::thread t3(printzero, end);
    std::thread t1(printodd, end);
    std::thread t2(printeven, end);

    cv1.notify_one();

    t1.join();
    t2.join();
    t3.join();

    return 0;
}

我的解决方案似乎处于僵局。我什至不确定逻辑是否正确。请帮助

1 个答案:

答案 0 :(得分:1)

您的代码有几个问题。为了使它起作用,这是您需要做的:

  1. 修改您的while (count <= end)支票。不同步地读取count是未定义的行为(UB)。
  2. std::condition_variable::wait使用正确的谓词。没有谓词的代码问题:
    • 如果在{em> notify_one之前调用wait,则该通知将丢失。在最坏的情况下,mainnotify_one的调用是在线程开始运行之前执行的。结果,所有线程可能会无限期等待。
    • Spurious wakeups可能会中断您的程序流程。另请参见cppreference.com on std::condition variable
  3. 使用std::flush(请确保使用)。

我玩了很多您的代码。在下面,您找到了我应用建议的修复程序的版本。此外,我还尝试了一些其他想法。

#include <cassert>

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// see the `std::mutex` for an example how to avoid global variables

std::condition_variable cv_zero{};
std::condition_variable cv_nonzero{};

bool done = false;
int next_digit = 1;
bool need_zero = true;

void print_zero(std::mutex& mt) {
  while(true) {// do not read shared state without holding a lock
    std::unique_lock<std::mutex> lk(mt);
    auto pred = [&] { return done || need_zero; };
    cv_zero.wait(lk, pred);
    if(done) break;

    std::cout << 0 << "\t"
              << -1 << "\t"// prove that it works
              << std::this_thread::get_id() << "\n"// prove that it works
              << std::flush;

    need_zero = false;

    lk.unlock();
    cv_nonzero.notify_all();// Let the other threads decide which one
                            // wants to proceed. This is probably less
                            // efficient, but preferred for
                            // simplicity.
  }
}

void print_nonzero(std::mutex& mt, int end, int n, int N) {
// Example for `n` and `N`: Launch `N == 2` threads with this
// function. Then the thread with `n == 1` prints all odd numbers, and
// the one with `n == 0` prints all even numbers.
  assert(N >= 1 && "number of 'nonzero' threads must be positive");
  assert(n >= 0 && n < N && "rank of this nonzero thread must be valid");

  while(true) {// do not read shared state without holding a lock
    std::unique_lock<std::mutex> lk(mt);
    auto pred = [&] { return done || (!need_zero && next_digit % N == n); };
    cv_nonzero.wait(lk, pred);
    if(done) break;

    std::cout << next_digit << "\t"
              << n << "\t"// prove that it works
              << std::this_thread::get_id() << "\n"// prove that it works
              << std::flush;

// Consider the edge case of `end == INT_MAX && next_digit == INT_MAX`.
// -> You need to check *before* incrementing in order to avoid UB.

    assert(next_digit <= end);
    if(next_digit == end) {
      done = true;
      cv_zero.notify_all();
      cv_nonzero.notify_all();
      break;
    }

    ++next_digit;
    need_zero = true;

    lk.unlock();
    cv_zero.notify_one();
  }
}

int main() {
  int end = 10;
  int N = 2;// number of threads for `print_nonzero`

  std::mutex mt{};// example how to pass by reference (avoiding globals)

  std::thread t_zero(print_zero, std::ref(mt));

// Create `N` `print_nonzero` threads with `n` in [0, `N`).
  std::vector<std::thread> ts_nonzero{};
  for(int n=0; n<N; ++n) {
// Note that it is important to pass `n` by value.
    ts_nonzero.emplace_back(print_nonzero, std::ref(mt), end, n, N);
  }

  t_zero.join();
  for(auto&& t : ts_nonzero) {
    t.join();
  }
}