我可以将这个小Python脚本与全局字典并行化吗?

时间:2015-04-26 17:02:21

标签: python

我有这个问题,我的MBA使用Python3需要大约2.8秒。因为在它的核心我们有一个缓存字典,我认为哪个调用首先命中缓存并不重要,所以也许我可以通过线程获得一些收益。不过,我无法弄明白。这比我通常提出的问题要高一些,但是有人可以引导我完成这个问题的并行化过程吗?

id
username
auth_key
password_hash
password_reset_token
email
status
created_at
access_token

或者,如果它是一个不好的候选人,为什么?

3 个答案:

答案 0 :(得分:1)

如果您的工作负载是CPU密集型的,那么Python中的线程将无法得到很好的提升。这是因为由于GIL(全局解释器锁定),一次只有一个线程实际上会使用处理器。

但是,如果您的工作负载受I / O限制(例如,等待来自网络请求的响应),则线程会给您一点提升,因为如果您的线程被阻塞等待网络响应,则另一个线程可以做有用的工作。

正如HDN所提到的,使用多处理将有所帮助 - 这使用多个Python解释器来完成工作。

我接近这个的方法是将迭代次数除以您计划创建的进程数。例如,如果您创建4个流程,请为每个流程提供1000000/4个工作片段。

最后,您需要汇总每个流程的结果并应用max()来获得结果。

答案 1 :(得分:0)

线程不会在性能提升方面给你太多,因为它不会绕过Global Interpreter Lock,它只会在任何给定时刻运行一个线程。由于上下文切换,它实际上甚至可能会减慢你的速度。

如果您希望在Python中利用并行化的性能,那么您将不得不使用多处理来实际利用多个核心。

答案 2 :(得分:0)

我设法在单核上将您的代码加速了 16.5x 倍,请进一步阅读。

如前所述,由于 Global Interpreter Lock,多线程对纯 Python 没有任何改进。

关于多处理 - 有两种选择 1) 实现共享字典并直接从不同进程读取/写入它。 2) 将值的范围分割成部分并解决不同进程的单独子范围的任务,然后从所有进程的答案中取最大值。

第一个选项会很慢,因为在你的代码中读/写字典是主要的耗时操作,使用进程间共享字典会减慢 5 倍,而多核没有任何改进。

第二个选项会带来一些改进,但也不是很好,因为不同的进程会多次重新计算相同的值。仅当您拥有非常多的内核或在集群中使用许多单独的机器时,此选项才会有相当大的改进。

我决定采用另一种方法来改进您的任务(选项 3) - 使用 Numba 并进行其他优化。我的解决方案也适用于选项 2(子范围的并行化)。

Numba 是 Just-in-Time 编译器和优化器,它将纯 Python 代码转换为优化的 C++,然后编译为机器代码。 Numba 通常可以提供 10 到 100 倍的加速。

要使用 numba 运行代码,您只需要安装 pip install numba(目前 Python 版本 <= 3.8 支持 Numba,很快也会支持 3.9!)。

我所做的所有改进都在单核上提供了 16.5x 倍的加速(例如,如果在您的算法中,某些范围内的时间为 64 秒,那么在我的代码中为 4 秒)。

我不得不重写您的代码,算法和想法与您的相同,但我将算法设为非递归(因为 Numba 不能很好地处理递归)并且还使用列表而不是字典来处理不太大的值。< /p>

我的基于 numba 的单核版本有时可能会使用太多内存,这仅仅是因为 cs 参数控制使用列表而不是字典的阈值,目前此 cs 设置为stop * 10(在代码中搜索),如果您没有太多内存,只需将其设置为例如stop * 2(但不少于 stop * 1)。我有 16GB 的内存,即使上限为 64000000,程序也能正常运行。

另外除了Numba代码我实现了C++解决方案,它的速度似乎和Numba一样,这意味着Numba做得很好! C++ 代码位于 Python 代码之后。

我对您的算法 (solve_py()) 和我的 (solve_nm()) 进行了计时测量并比较了它们。时间列在代码之后。

仅供参考,我也使用我的 numba 解决方案做了多核处理版本,但它没有比单核版本有任何改进,甚至速度变慢。这一切都是因为多核版本多次计算相同的值。多机版可能会有明显的提升,但多核可能不会。

由于免费在线服务器上的内存有限,下面的在线试用链接只允许运行小范围!

Try it online!

import time, threading, time, numba

def solve_py(start, stop):
    even = lambda n: n%2==0
    next_collatz = lambda n: n//2 if even(n) else 3*n+1

    cache = {1: 1}
    def collatz_chain_length(n):
        if n not in cache: cache[n] = 1 + collatz_chain_length(next_collatz(n))
        return cache[n]

    for n in range(start, stop):
        collatz_chain_length(n)

    r = max(range(start, stop), key = cache.get)
    return r, cache[r]

@numba.njit(cache = True, locals = {'n': numba.int64, 'l': numba.int64, 'zero': numba.int64})
def solve_nm(start, stop):
    zero, l, cs = 0, 0, stop * 10
    ns = [zero] * 10000
    cache_lo = [zero] * cs
    cache_lo[1] = 1
    cache_hi = {zero: zero}
    for n in range(start, stop):
        if cache_lo[n] != 0:
            continue
        nsc = 0
        while True:
            if n < cs:
                cg = cache_lo[n]
            else:
                cg = cache_hi.get(n, zero)
            if cg != 0:
                l = 1 + cg
                break
            ns[nsc] = n
            nsc += 1
            n = (n >> 1) if (n & 1) == 0 else 3 * n + 1
        for i in range(nsc - 1, -1, -1):
            if ns[i] < cs:
                cache_lo[ns[i]] = l
            else:
                cache_hi[ns[i]] = l
            l += 1
    maxn, maxl = 0, 0
    for k in range(start, stop):
        v = cache_lo[k]
        if v > maxl:
            maxn, maxl = k, v
    return maxn, maxl

if __name__ == '__main__':
    solve_nm(1, 100000) # heat-up, precompile numba
    for stop in [1000000, 2000000, 4000000, 8000000, 16000000, 32000000, 64000000]:
        tr, resr = None, None
        for is_nm in [False, True]:
            if stop > 16000000 and not is_nm:
                continue
            tb = time.time()
            res = (solve_nm if is_nm else solve_py)(1, stop)
            te = time.time()
            print(('py', 'nm')[is_nm], 'limit', stop, 'time', round(te - tb, 2), 'secs', end = '')
            if not is_nm:
                resr, tr = res, te - tb
                print(', n', res[0], 'len', res[1])
            else:
                if tr is not None:
                    print(', boost', round(tr / (te - tb), 2))
                    assert resr == res, (resr, res)
                else:
                    print(', n', res[0], 'len', res[1])

输出:

py limit 1000000 time 3.34 secs, n 837799 len 525
nm limit 1000000 time 0.19 secs, boost 17.27
py limit 2000000 time 6.72 secs, n 1723519 len 557
nm limit 2000000 time 0.4 secs, boost 16.76
py limit 4000000 time 13.47 secs, n 3732423 len 597
nm limit 4000000 time 0.83 secs, boost 16.29
py limit 8000000 time 27.32 secs, n 6649279 len 665
nm limit 8000000 time 1.68 secs, boost 16.27
py limit 16000000 time 55.42 secs, n 15733191 len 705
nm limit 16000000 time 3.48 secs, boost 15.93
nm limit 32000000 time 7.38 secs, n 31466382 len 706
nm limit 64000000 time 16.83 secs, n 63728127 len 950

与 Numba 相同算法的 C++ 版本位于下面:

Try it online!

#include <cstdint>
#include <vector>
#include <unordered_map>
#include <tuple>
#include <iostream>
#include <stdexcept>
#include <chrono>

typedef int64_t i64;

static std::tuple<i64, i64> Solve(i64 start, i64 stop) {
    i64 cs = stop * 10, n = 0, l = 0, nsc = 0;
    std::vector<i64> cache_lo(cs), ns(10000);
    cache_lo[1] = 1;
    std::unordered_map<i64, i64> cache_hi;
    for (i64 i = start; i < stop; ++i) {
        if (cache_lo[i] != 0)
            continue;
        n = i;
        nsc = 0;
        while (true) {
            i64 cg = 0;
            if (n < cs)
                cg = cache_lo[n];
            else {
                auto it = cache_hi.find(n);
                if (it != cache_hi.end())
                    cg = it->second;
            }
            if (cg != 0) {
                l = 1 + cg;
                break;
            }
            ns.at(nsc) = n;
            ++nsc;
            n = (n & 1) ? 3 * n + 1 : (n >> 1);
        }
        for (i64 i = nsc - 1; i >= 0; --i) {
            i64 n = ns[i];
            if (n < cs)
                cache_lo[n] = l;
            else
                cache_hi[n] = l;
            ++l;
        }
    }
    i64 maxn = 0, maxl = 0;
    for (size_t i = start; i < stop; ++i)
        if (cache_lo[i] > maxl) {
            maxn = i;
            maxl = cache_lo[i];
        }
    return std::make_tuple(maxn, maxl);
}

int main() {
    try {
        for (auto stop: std::vector<i64>({1000000, 2000000, 4000000, 8000000, 16000000, 32000000, 64000000})) {
            auto tb = std::chrono::system_clock::now();
            auto r = Solve(1, stop);
            auto te = std::chrono::system_clock::now();
            std::cout << "cpp limit " << stop
                << " time " << double(std::chrono::duration_cast<std::chrono::milliseconds>(te - tb).count()) / 1000.0 << " secs"
                << ", n " << std::get<0>(r) << " len " << std::get<1>(r) << std::endl;
        }
        return 0;
    } catch (std::exception const & ex) {
        std::cout << "Exception: " << ex.what() << std::endl;
        return -1;
    }
}

输出:

cpp limit 1000000 time 0.17 secs, n 837799 len 525
cpp limit 2000000 time 0.357 secs, n 1723519 len 557
cpp limit 4000000 time 0.757 secs, n 3732423 len 597
cpp limit 8000000 time 1.571 secs, n 6649279 len 665
cpp limit 16000000 time 3.275 secs, n 15733191 len 705
cpp limit 32000000 time 7.112 secs, n 31466382 len 706
cpp limit 64000000 time 17.165 secs, n 63728127 len 950