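How can I parallelize sorting?

Time: 2017-11-11 07:12:16

Tags: multithreading sorting parallel-processing julia

I want to sort a lot of things.

Julia's standard-library sort is single-threaded. How can I take advantage of my multicore machine to sort things faster?

2 answers:

Answer 0 (score: 12)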

Here is a solution using the (experimental) Base.Threads threading module.

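A solution using pmap (etc.) for distributed parallelism would be similar, though I think the interprocess communication overhead would hurt you.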
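A minimal sketch of the distributed variant, assuming Julia 1.x and its `Distributed` standard library; `pmap_sort_chunks` is a hypothetical helper, not part of the answer. `pmap` serializes each chunk to a worker and the sorted chunk back, which is the interprocess communication cost mentioned above. With no workers added it simply runs on the master process; call `addprocs(n)` first to use other processes. The sorted chunks still have to be merged afterwards.

```julia
using Distributed

# Split `data` into about `nchunks` contiguous chunks (copies) and sort each one with pmap.
# Every chunk is sent to a worker and its sorted copy is sent back.
function pmap_sort_chunks(data::Vector, nchunks::Integer)
    n = length(data)
    chunk_len = max(1, cld(n, max(1, nchunks)))   # ceiling division; never a zero step
    chunks = [data[i:min(i + chunk_len - 1, n)] for i in 1:chunk_len:n]
    return pmap(sort, chunks)                     # one sorted vector per chunk
end

# addprocs(4)  # uncomment to sort on 4 extra worker processes
sorted_chunks = pmap_sort_chunks(rand(Int, 10_000), 4)
```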

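The idea is to sort the data in blocks (one per thread), so each thread can work completely independently, handling only its own block.

Then merge these pre-sorted blocks.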

This is the well-known problem of merging sorted lists; see also other questions.
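For reference, the basic building block is merging two sorted lists in linear time. A minimal sketch; `merge_sorted` is a hypothetical name, and folding it over the blocks left to right is just one simple way to combine more than two of them (the answer's code below uses a different, k-way approach).

```julia
# Merge two already-sorted vectors into one sorted vector in O(length(a) + length(b)).
function merge_sorted(a::AbstractVector, b::AbstractVector)
    out = Vector{promote_type(eltype(a), eltype(b))}(undef, length(a) + length(b))
    i, j = firstindex(a), firstindex(b)
    for k in eachindex(out)
        # Take from `a` when `b` is exhausted or a's head is not larger (ties go to `a`, so it is stable).
        if j > lastindex(b) || (i <= lastindex(a) && !isless(b[j], a[i]))
            out[k] = a[i]; i += 1
        else
            out[k] = b[j]; j += 1
        end
    end
    return out
end

# Several sorted blocks can be merged by folding the two-way merge over them.
blocks = [[1, 4, 9], [2, 3, 10], [5, 6, 7, 8]]
merged = reduce(merge_sorted, blocks)
```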

Before you start, don't forget to enable multithreading by setting the environment variable JULIA_NUM_THREADS.
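To check that the setting took effect, assuming Julia was started with the variable set (for example `JULIA_NUM_THREADS=8 julia` in a POSIX shell; Julia 1.5 and later also accept `julia --threads 8`), you can query the thread count and see which threads a `@threads` loop actually uses:

```julia
using Base.Threads

# Number of threads this session was started with; 1 means no multithreading.
println("Julia is running with $(nthreads()) thread(s)")

# Record which thread ran each iteration of a @threads loop.
owner = zeros(Int, 16)
@threads for i in 1:16
    owner[i] = threadid()
end
println("threads used: ", sort(unique(owner)))
```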

Here is my code:

using Base.Threads

# Split 1:total_len into nblocks contiguous ranges whose lengths differ by at most 1.
# Returns (starts, ends) such that block k is starts[k]:ends[k].
function blockranges(nblocks, total_len)
    extra = total_len % nblocks   # the first `extra` blocks get one more element
    main_len = div(total_len, nblocks)

    starts = Int[1]
    ends = Int[]
    for ii in 1:nblocks
        len = main_len
        if extra > 0
            len += 1
            extra -= 1
        end
        push!(ends, starts[end] + len - 1)
        push!(starts, ends[end] + 1)
    end
    @assert ends[end] == total_len
    starts[1:end-1], ends
end

function threadedsort!(data::Vector)
    # Never use more blocks than elements: an empty block would make the
    # merge step below read past the end of its range.
    nblocks = max(1, min(nthreads(), length(data)))
    starts, ends = blockranges(nblocks, length(data))

    # Sort each block independently, one block per thread
    @threads for blockid in 1:nblocks
        @inbounds sort!(@view data[starts[blockid]:ends[blockid]])
    end

    # Go through each sorted block taking out the smallest item and putting it in the new array
    # This code could maybe be optimised. see https://stackoverflow.com/a/22057372/179081
    ret = similar(data) # main bit of allocation right here. avoiding it seems expensive.
    # Need to not overwrite data we haven't read yet
    @inbounds for ii in eachindex(ret)
        minblock_id = 1
        ret[ii] = data[starts[1]]
        @inbounds for blockid in 2:lastindex(starts) # findmin allocates a lot for some reason, so do the find by hand. (maybe use findmin! ?)
            ele = data[starts[blockid]]
            if ret[ii] > ele
                ret[ii] = ele
                minblock_id = blockid
            end
        end
        starts[minblock_id] += 1 # move the start point forward
        if starts[minblock_id] > ends[minblock_id] # block exhausted: drop it
            deleteat!(starts, minblock_id)
            deleteat!(ends, minblock_id)
        end
    end
    data .= ret  # copy back into the original, as we said we would do it in place
    return data
end
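The comment in the merge loop notes it could be optimised: scanning every block head costs O(k) per output element for k blocks. A common alternative, swapped in here and not the answer's code, is a binary min-heap of block ids keyed by each block's current head, costing O(log k) per element. A sketch, where `heap_merge!` is a hypothetical helper that takes the same `starts`/`ends` layout that `blockranges` returns and does not mutate `starts`; it has not been benchmarked, and with only 8 blocks the gain may be small.

```julia
# k-way merge of sorted blocks data[starts[k]:ends[k]] into `out` using a binary min-heap.
# Assumes length(out) equals the total number of elements in all blocks.
function heap_merge!(out::AbstractVector, data::AbstractVector,
                     starts::Vector{Int}, ends::Vector{Int})
    pos = copy(starts)                                   # current head index of each block
    heap = [k for k in eachindex(starts) if starts[k] <= ends[k]]  # ids of non-empty blocks
    key(k) = data[pos[k]]                                # current head value of block k

    # Move heap[i] down until neither child has a smaller key.
    function siftdown!(i)
        n = length(heap)
        while true
            l, r, m = 2i, 2i + 1, i
            if l <= n && isless(key(heap[l]), key(heap[m]))
                m = l
            end
            if r <= n && isless(key(heap[r]), key(heap[m]))
                m = r
            end
            m == i && return nothing
            heap[i], heap[m] = heap[m], heap[i]
            i = m
        end
    end

    for i in (length(heap) ÷ 2):-1:1                     # build the heap bottom-up
        siftdown!(i)
    end
    for ii in eachindex(out)
        k = heap[1]                                       # block with the smallest head
        out[ii] = key(k)
        pos[k] += 1
        if pos[k] > ends[k]                               # block exhausted: move last leaf to root
            heap[1] = heap[end]
            pop!(heap)
        end
        isempty(heap) || siftdown!(1)
    end
    return out
end

data = [1, 4, 7, 2, 5, 8, 3, 6, 9]
merged = heap_merge!(similar(data), data, [1, 4, 7], [3, 6, 9])
```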

I did some benchmarking:

using Plots

# Time Base.sort! and threadedsort! on random Int vectors of length 2^k for k in `range`.
# The first iteration includes JIT compilation time.
function evaluate_timing(range)
    sizes = Int[]
    threadsort_times = Float64[]
    sort_times = Float64[]
    for sz in 2 .^ collect(range)
        data_orig = rand(Int, sz)
        push!(sizes, sz)

        data = copy(data_orig)  # both sorts get the same unsorted input
        push!(sort_times,       @elapsed sort!(data))

        data = copy(data_orig)
        push!(threadsort_times, @elapsed threadedsort!(data))

        @show (sz, sort_times[end], threadsort_times[end])
    end
    return sizes, threadsort_times, sort_times
end

sizes, threadsort_times, sort_times = evaluate_timing(0:28)
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadedsort!" "sort!"])
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadedsort!" "sort!"], xscale=:log10, yscale=:log10)

My results, using 8 threads:

[Figures: sorting time vs. number of elements, normal scale and log-log scale]

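I found the crossover point to be surprisingly low: a little above 1024 elements. Note that the long time at the very start can be ignored; that is the code being JIT-compiled on its first run.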
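One way to keep JIT compilation out of a single measurement is to call the function once on a throwaway copy before timing it. A minimal sketch; `time_after_warmup` is a hypothetical helper, not part of the answer.

```julia
# Time f!(copy(data)) after one warm-up call, so JIT compilation is not counted.
function time_after_warmup(f!, data::Vector)
    f!(copy(data))          # first call: compiles f! for this argument type
    fresh = copy(data)      # time on unsorted input, not on the warm-up result
    return @elapsed f!(fresh)
end

data = rand(Int, 2^16)
t = time_after_warmup(sort!, data)
println("sort! took $t seconds")
```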

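Strangely, these results could not be reproduced with BenchmarkTools, even though BenchmarkTools keeps the initial (compilation) time out of the measurement. With the plain timing code from the benchmark above, however, they reproduce very consistently. I guess BenchmarkTools is doing something that kills the multithreading.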
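Whatever the cause here, one pitfall worth ruling out when benchmarking an in-place sort is that repeated calls on the same array time every call after the first on already-sorted input. BenchmarkTools avoids this with `setup=(x = copy(data)) evals=1`; the stdlib-only sketch below does the equivalent by hand, timing several samples, each on a fresh copy, and reporting the minimum. `min_time_fresh` is a hypothetical name, and this is not a confirmed explanation of the discrepancy above.

```julia
# Minimum time of f! over `samples` runs, each run on its own fresh copy of `data`.
# A warm-up call first keeps JIT compilation out of the measurements.
function min_time_fresh(f!, data::Vector; samples::Integer = 5)
    f!(copy(data))                        # warm-up / compilation
    best = Inf
    for _ in 1:samples
        x = copy(data)                    # fresh, unsorted input; copying is not timed
        best = min(best, @elapsed f!(x))
    end
    return best
end

data = rand(Int, 2^18)
println("sort!: ", min_time_fresh(sort!, data), " s")
```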

Many thanks to @xiaodai, who pointed out a mistake in my profiling code.

Answer 1 (score: 1)

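I have tested further, with only 1% of the items being unique, and also with values sampled from 1:1_000_000. Results below:

[Figure: sorting time, about 1% unique values]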

[Figure: sorting time, values sampled from 1:1_000_000]

Code:

function evaluate_timing_w_repeats(range)
    sizes = Int[]
    threadsort_times = Float64[]
    sort_times = Float64[]
    for sz in 2 .^ collect(range)
        # draw sz values from a pool of sz ÷ 100 random Ints, so at most about 1% are unique
        data_orig = rand(rand(Int, sz ÷ 100), sz)
        push!(sizes, sz)

        data = copy(data_orig)
        push!(sort_times,       @elapsed sort!(data))

        data = copy(data_orig)
        push!(threadsort_times, @elapsed threadedsort!(data))

        @show (sz, sort_times[end], threadsort_times[end])
    end
    return sizes, threadsort_times, sort_times
end

sizes, threadsort_times, sort_times = evaluate_timing_w_repeats(7:28)
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"])
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"], xscale=:log10, yscale=:log10)
savefig("sort_with_repeats.png")
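The repeated-values input is built with `rand(rand(Int, sz ÷ 100), sz)`: first draw `sz ÷ 100` random integers as a pool, then draw `sz` samples (with replacement) from that pool, so at most about 1% of the values are distinct. This is presumably also why that benchmark starts at 2^7: for `sz < 100` the pool would be empty and there would be nothing to sample from. A quick check:

```julia
sz = 2^12
pool = rand(Int, sz ÷ 100)      # 40 candidate values
data = rand(pool, sz)           # 4096 samples drawn with replacement from the pool
ndistinct = length(unique(data))
println("$ndistinct distinct values out of $sz")
```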

function evaluate_timing1m(range)
    sizes = Int[]
    threadsort_times = Float64[]
    sort_times = Float64[]
    for sz in 2 .^ collect(range)
        data_orig = rand(1:1_000_000, sz)  # values sampled from 1:1_000_000
        push!(sizes, sz)

        data = copy(data_orig)
        push!(sort_times,       @elapsed sort!(data))

        data = copy(data_orig)
        push!(threadsort_times, @elapsed threadedsort!(data))

        @show (sz, sort_times[end], threadsort_times[end])
    end
    return sizes, threadsort_times, sort_times
end

sizes, threadsort_times, sort_times = evaluate_timing1m(7:28)
plot(sizes, [threadsort_times sort_times]; title="Sorting Time", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"])
plot(sizes, [threadsort_times sort_times]; title="Sorting Time, sampled from 1:1_000_000", ylabel="time(s)", xlabel="number of elements", label=["threadsort!" "sort!"], xscale=:log10, yscale=:log10)
savefig("sort1m.png")