Is there an API to race N threads (or N closures on N threads) to completion?

Date: 2018-04-15 19:42:59

Tags: multithreading concurrency rust

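If several threads complete with an Output value, how do I get the first Output that is produced? Ideally, I would still be able to get the remaining Outputs in the order they are produced, keeping in mind that some threads may or may not terminate.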

Example:

struct Output(i32);

fn main() {
    let mut spawned_threads = Vec::new();

    for i in 0..10 {
        let join_handle: ::std::thread::JoinHandle<Output> = ::std::thread::spawn(move || {
            // pretend to do some work that takes some amount of time
            ::std::thread::sleep(::std::time::Duration::from_millis(
                (1000 - (100 * i)) as u64,
            ));
            Output(i) // then pretend to return the `Output` of that work
        });
        spawned_threads.push(join_handle);
    }

    // I can do this to wait for each thread to finish and collect all `Output`s
    let outputs_in_order_of_thread_spawning = spawned_threads
        .into_iter()
        .map(::std::thread::JoinHandle::join)
        .collect::<Vec<::std::thread::Result<Output>>>();

    // but how would I get the `Output`s in order of completed threads?
}

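I could solve this myself with a shared queue, a channel or something similar, but is there a built-in API or an existing library that handles this use case more elegantly?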

I'm looking for an API like this:

fn race_threads<A: Send>(
    threads: Vec<::std::thread::JoinHandle<A>>
) -> (::std::thread::Result<A>, Vec<::std::thread::JoinHandle<A>>) {
    unimplemented!("so far this doesn't seem to exist")
}

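Rayon's join is the closest thing I could find, but a) it can only race 2 closures rather than an arbitrary number, and b) a thread pool with work stealing does not make sense for my use case, where some closures might run forever.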

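This use case could be solved using the pointers from How to check if a thread has finished in Rust?, just as it could be solved with an MPSC channel, but what I'm after is a clean API to race n threads (or, failing that, n closures on n threads).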
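For comparison, the polling approach from that linked question can be wrapped in exactly the `race_threads` signature above. This is a minimal sketch, assuming Rust 1.61 or newer, where `JoinHandle::is_finished` is stable (it postdates this question); the 1 ms poll interval is an arbitrary choice, and because of `swap_remove` the remaining handles come back in no particular order.

```rust
use std::thread::{self, JoinHandle};
use std::time::Duration;

// Hypothetical helper matching the signature asked for above.
// Polls until any handle reports it has finished, joins that one, and
// returns the still-pending handles (their order is not preserved).
fn race_threads<A: Send>(
    mut threads: Vec<JoinHandle<A>>,
) -> (thread::Result<A>, Vec<JoinHandle<A>>) {
    assert!(!threads.is_empty(), "race_threads needs at least one thread");
    loop {
        if let Some(pos) = threads.iter().position(|t| t.is_finished()) {
            // swap_remove is O(1); it moves the last handle into `pos`
            let winner = threads.swap_remove(pos);
            return (winner.join(), threads);
        }
        // Nothing has finished yet: sleep briefly instead of spinning
        thread::sleep(Duration::from_millis(1));
    }
}

fn main() {
    let mut pending: Vec<JoinHandle<u64>> = (0..10u64)
        .map(|i| {
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(1000 - 100 * i));
                i
            })
        })
        .collect();

    // Repeatedly race the remaining threads to get results in completion order
    while !pending.is_empty() {
        let (result, rest) = race_threads(pending);
        println!("{:?}", result);
        pending = rest;
    }
}
```

A thread that never finishes simply stays in the returned Vec; dropping its handle detaches the thread rather than stopping it.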

2 Answers:

Answer 0 (score: 1)

These problems can be solved with a condition variable:

use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

#[derive(Debug)]
struct Output(i32);

enum State {
    Starting,
    Joinable,
    Joined,
}

fn main() {
    let pair = Arc::new((Mutex::new(Vec::new()), Condvar::new()));
    let mut spawned_threads = Vec::new();

    let (lock, cvar) = &*pair;
    for i in 0..10 {
        // Register the slot before spawning, so the child can never index past the end
        lock.lock().unwrap().push(State::Starting);

        let my_pair = Arc::clone(&pair);
        let join_handle: thread::JoinHandle<Output> = thread::spawn(move || {
            // pretend to do some work that takes some amount of time
            thread::sleep(Duration::from_millis((1000 - (100 * i)) as u64));

            let (lock, cvar) = &*my_pair;
            lock.lock().unwrap()[i] = State::Joinable;
            cvar.notify_one();
            Output(i as i32) // then pretend to return the `Output` of that work
        });
        spawned_threads.push(Some(join_handle));
    }

    // Hold the lock while inspecting the states; `wait` releases it atomically,
    // so a notification sent after the check cannot be missed
    let mut states = lock.lock().unwrap();
    loop {
        let mut any_starting = false;
        for (i, state) in states.iter_mut().enumerate() {
            match *state {
                State::Starting => any_starting = true,
                State::Joinable => {
                    *state = State::Joined;
                    println!("{:?}", spawned_threads[i].take().unwrap().join());
                }
                State::Joined => (),
            }
        }
        if !any_starting {
            break;
        }
        states = cvar.wait(states).unwrap();
    }
}

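I don't claim this is the simplest way to do it. The condition variable wakes the main thread each time a child thread finishes. The list records the state of every thread, so any thread that has finished (or is about to) can be joined.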
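A simplified variant of the same condition-variable idea, sketched as an illustration rather than the answerer's code: each worker pushes its own index onto a shared `VecDeque` and signals the condvar, so the main thread pops indices in completion order and joins exactly those handles. `join_in_completion_order` is a hypothetical name, `Condvar::wait_while` needs Rust 1.42+, and the sketch assumes workers do not panic (a worker that panics before pushing its index would leave the loop waiting forever).

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

#[derive(Debug)]
struct Output(i32);

// Hypothetical helper: spawns `n` workers and joins them in the order they finish.
fn join_in_completion_order(n: usize) -> Vec<thread::Result<Output>> {
    // Queue of indices of workers that have finished, plus its condvar
    let shared = Arc::new((Mutex::new(VecDeque::new()), Condvar::new()));
    let mut handles = Vec::new();

    for i in 0..n {
        let shared = Arc::clone(&shared);
        handles.push(Some(thread::spawn(move || {
            // pretend to do some work; later indices finish sooner
            thread::sleep(Duration::from_millis((100 * (n - i)) as u64));
            let (queue, cvar) = &*shared;
            queue.lock().unwrap().push_back(i);
            cvar.notify_one();
            Output(i as i32)
        })));
    }

    let (queue, cvar) = &*shared;
    let mut results = Vec::new();
    while results.len() < n {
        // wait_while re-checks the predicate under the lock, so an index pushed
        // before we start waiting is not missed, and spurious wakeups are handled
        let mut done = cvar.wait_while(queue.lock().unwrap(), |q| q.is_empty()).unwrap();
        let i = done.pop_front().unwrap();
        drop(done); // release the lock before blocking on join
        results.push(handles[i].take().unwrap().join());
    }
    results
}

fn main() {
    for result in join_in_completion_order(10) {
        println!("{:?}", result);
    }
}
```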

Answer 1 (score: 0)

No, there is no such API.

You already have several ways to solve your problem.

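Sometimes when programming you have to go beyond the prefabricated building blocks. That is supposed to be the fun part of programming, and I encourage you to embrace it: build your ideal API out of the available components and publish it on crates.io.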
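As one illustration of that advice, here is a sketch of a small racing helper assembled from `std::sync::mpsc` and `std::panic::catch_unwind`. The name `race` and its signature are hypothetical (not part of std or any particular crate). It yields `(index, thread::Result<T>)` pairs in completion order, so a panicking closure shows up as an `Err` instead of silently dropping its sender. `catch_unwind` only catches unwinding panics, so this assumes the default `panic = "unwind"` strategy.

```rust
use std::panic::{self, AssertUnwindSafe};
use std::sync::mpsc::{self, Receiver};
use std::thread;
use std::time::Duration;

// Hypothetical helper: runs each closure on its own thread and returns a
// receiver yielding `(index, result)` pairs in the order the closures finish.
// Panics are caught and delivered as `Err` instead of being lost.
fn race<T, F>(closures: Vec<F>) -> Receiver<(usize, thread::Result<T>)>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    for (index, f) in closures.into_iter().enumerate() {
        let tx = tx.clone();
        thread::spawn(move || {
            let result = panic::catch_unwind(AssertUnwindSafe(f));
            // If the caller dropped the receiver, nobody wants the result
            let _ = tx.send((index, result));
        });
    }
    // The original `tx` is dropped on return, so iterating `rx` ends
    // once every spawned thread has finished
    rx
}

fn main() {
    let closures: Vec<_> = (0..10u64)
        .map(|i| {
            move || {
                thread::sleep(Duration::from_millis(1000 - 100 * i));
                i
            }
        })
        .collect();

    for (index, result) in race(closures) {
        println!("closure {} finished: {:?}", index, result);
    }
}
```

Returning a plain `Receiver` keeps the helper small: the caller can take just the first result with `recv`, iterate over all of them, or use `recv_timeout` for closures that might run forever.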

I don't really see what's so bad about the channel version:

use std::{sync::mpsc, thread, time::Duration};

#[derive(Debug)]
struct Output(i32);

fn main() {
    let (tx, rx) = mpsc::channel();

    for i in 0..10 {
        let tx = tx.clone();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis((1000 - (100 * i)) as u64));
            tx.send(Output(i)).unwrap();
        });
    }
    // Don't hold on to the sender ourselves
    // Otherwise the loop would never terminate
    drop(tx);

    for r in rx {
        println!("{:?}", r);
    }
}
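One caveat for the use case in the question: if some threads never terminate, `for r in rx` blocks forever, because those threads still hold a `Sender`. Below is a minimal sketch that takes the first result with `recv` and then gathers the rest only until a deadline, using `Receiver::recv_timeout`. The helper name `collect_until` and the 2-second deadline are assumptions for illustration.

```rust
use std::sync::mpsc::{self, Receiver, RecvTimeoutError};
use std::thread;
use std::time::{Duration, Instant};

#[derive(Debug)]
struct Output(i32);

// Hypothetical helper: drain `rx` until all senders are gone or `deadline` passes.
// Workers that never finish are not stopped; they keep running detached.
fn collect_until(rx: &Receiver<Output>, deadline: Instant) -> Vec<Output> {
    let mut outputs = Vec::new();
    loop {
        // Time left until the deadline (zero once it has passed)
        let timeout = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(timeout) {
            Ok(output) => outputs.push(output),
            // Either the deadline expired or every sender was dropped
            Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    outputs
}

fn main() {
    let (tx, rx) = mpsc::channel();
    for i in 0..10 {
        let tx = tx.clone();
        thread::spawn(move || {
            if i == 0 {
                // pretend this worker never terminates
                loop {
                    thread::sleep(Duration::from_secs(60));
                }
            }
            thread::sleep(Duration::from_millis((1000 - 100 * i) as u64));
            tx.send(Output(i)).unwrap();
        });
    }
    drop(tx);

    // Block until the first result arrives, however long that takes
    let first = rx.recv().expect("all workers exited without sending");
    println!("first: {:?}", first);

    // Collect the rest, but give up after two seconds
    let rest = collect_until(&rx, Instant::now() + Duration::from_secs(2));
    println!("rest: {:?}", rest);
}
```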