优化parallel_for实现

时间:2015-12-11 10:44:37

标签: c++ c++11 concurrency parallel-for

我有一些使用微软PPL的parallel_for循环的代码,后来需要把它移植到Linux和Mac上,于是我自己实现了一个版本。它功能正确、性能也不错,但仍比功能相同的PPL parallel_for循环慢约20%。

需要说明的是,循环通常执行10到10万次迭代,而每次迭代只包含几次平方根和乘法运算。由于它用于交互式应用程序,必须运行得非常快。

我对C++11还不太熟悉,所以如果有经验的人能看看我的实现并给些反馈——为什么它的性能还赶不上PPL,以及有哪些可以改进的地方——我会非常感激。

/// Invokes userLambda(i) for every index i in [start, end), splitting the range
/// into contiguous blocks that run concurrently.
///
/// @tparam THREADS_PER_CORE multiplier on std::thread::hardware_concurrency()
///         used to choose how many blocks to create (never more than one per index).
/// @param start      first index, inclusive.
/// @param end        one past the last index. An empty or reversed range
///                   (end <= start) does nothing.
/// @param userLambda called once per index, concurrently from several threads
///                   (including the calling thread), so it must tolerate that.
/// @throws Rethrows an exception raised by userLambda once every block has
///         finished: the calling thread's block is checked first, then the
///         worker blocks in index order.
///
/// NOTE(review): each std::async(std::launch::async) call may start a fresh
/// thread, and std::function prevents inlining of the per-index call; a
/// persistent thread pool and a templated callable are the likely next steps
/// toward PPL's performance.
template<size_t THREADS_PER_CORE = 1>
void parallel_forMine(size_t start, size_t end, const std::function<void(size_t)> &userLambda)
{
    if (end <= start)
        return;
    const size_t count = end - start;

    // hardware_concurrency() may return 0 when the value is not computable;
    // fall back to a single block instead of dividing by zero below.
    size_t threadCount = static_cast<size_t>(std::thread::hardware_concurrency()) * THREADS_PER_CORE;
    if (threadCount == 0)
        threadCount = 1;
    if (threadCount > count)
        threadCount = count;

    // Ceiling division, written so it cannot overflow (count + threadCount - 1 could).
    const size_t blockSize = count / threadCount + (count % threadCount != 0 ? 1 : 0);

    // Processes one contiguous block [first, last).
    auto runBlock = [&userLambda](size_t first, size_t last)
    {
        for (size_t i = first; i < last; ++i)
            userLambda(i);
    };

    std::vector<std::future<void>> futures;
    futures.reserve(threadCount);

    // Hand every block except the last to a worker. Indices stay size_t so large
    // ranges are not truncated; the loop condition guarantees blockStart + blockSize < end.
    size_t blockStart = start;
    while (end - blockStart > blockSize)
    {
        const size_t blockEnd = blockStart + blockSize;
        futures.push_back(std::async(std::launch::async, runBlock, blockStart, blockEnd));
        blockStart = blockEnd;
    }

    // The calling thread would otherwise only wait, so it processes the final block itself.
    std::exception_ptr error;
    try
    {
        runBlock(blockStart, end);
    }
    catch (...)
    {
        error = std::current_exception();
    }

    // Join every worker before rethrowing so no task outlives userLambda.
    for (std::future<void> &f : futures)
    {
        try
        {
            f.get();
        }
        catch (...)
        {
            if (!error)
                error = std::current_exception();
        }
    }

    if (error)
        std::rethrow_exception(error);
}

完整的测试代码如下。

#include "stdafx.h" //nothing in there in this test
#include <chrono>
#include <cmath>
#include <exception>
#include <functional>
#include <future>
#include <iostream>
#include <thread>
#include <vector>

#include <ppl.h>

/// Invokes userLambda(i) for every index i in [start, end), splitting the range
/// into contiguous blocks that run concurrently.
///
/// @tparam THREADS_PER_CORE multiplier on std::thread::hardware_concurrency()
///         used to choose how many blocks to create (never more than one per index).
/// @param start      first index, inclusive.
/// @param end        one past the last index. An empty or reversed range
///                   (end <= start) does nothing.
/// @param userLambda called once per index, concurrently from several threads
///                   (including the calling thread), so it must tolerate that.
/// @throws Rethrows an exception raised by userLambda once every block has
///         finished: the calling thread's block is checked first, then the
///         worker blocks in index order.
///
/// NOTE(review): each std::async(std::launch::async) call may start a fresh
/// thread, and std::function prevents inlining of the per-index call; a
/// persistent thread pool and a templated callable are the likely next steps
/// toward PPL's performance.
template<size_t THREADS_PER_CORE = 1>
void parallel_forMine(size_t start, size_t end, const std::function<void(size_t)> &userLambda)
{
    if (end <= start)
        return;
    const size_t count = end - start;

    // hardware_concurrency() may return 0 when the value is not computable;
    // fall back to a single block instead of dividing by zero below.
    size_t threadCount = static_cast<size_t>(std::thread::hardware_concurrency()) * THREADS_PER_CORE;
    if (threadCount == 0)
        threadCount = 1;
    if (threadCount > count)
        threadCount = count;

    // Ceiling division, written so it cannot overflow (count + threadCount - 1 could).
    const size_t blockSize = count / threadCount + (count % threadCount != 0 ? 1 : 0);

    // Processes one contiguous block [first, last).
    auto runBlock = [&userLambda](size_t first, size_t last)
    {
        for (size_t i = first; i < last; ++i)
            userLambda(i);
    };

    std::vector<std::future<void>> futures;
    futures.reserve(threadCount);

    // Hand every block except the last to a worker. Indices stay size_t so large
    // ranges are not truncated; the loop condition guarantees blockStart + blockSize < end.
    size_t blockStart = start;
    while (end - blockStart > blockSize)
    {
        const size_t blockEnd = blockStart + blockSize;
        futures.push_back(std::async(std::launch::async, runBlock, blockStart, blockEnd));
        blockStart = blockEnd;
    }

    // The calling thread would otherwise only wait, so it processes the final block itself.
    std::exception_ptr error;
    try
    {
        runBlock(blockStart, end);
    }
    catch (...)
    {
        error = std::current_exception();
    }

    // Join every worker before rethrowing so no task outlives userLambda.
    for (std::future<void> &f : futures)
    {
        try
        {
            f.get();
        }
        catch (...)
        {
            if (!error)
                error = std::current_exception();
        }
    }

    if (error)
        std::rethrow_exception(error);
}



/// Benchmarks a serial loop, PPL's Concurrency::parallel_for and parallel_forMine
/// on the same workload, prints each elapsed time in microseconds, and checks
/// that all three produced the same results.
/// NOTE(review): <ppl.h> / Concurrency::parallel_for are Microsoft-specific, so the
/// PPL section must be guarded or removed when building on Linux/Mac.
int main()
{
    using Clock = std::chrono::high_resolution_clock;

    // Elapsed time since `begin` in whole microseconds. duration_cast makes this
    // independent of the clock's tick period; count() alone is in clock ticks,
    // which are not guaranteed to be nanoseconds.
    auto elapsedMicroseconds = [](Clock::time_point begin)
    {
        return std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() - begin).count();
    };

    // Per-iteration workload shared by all three variants. The std:: functions from
    // <cmath> are used explicitly: an unqualified abs() can resolve to the C
    // int abs(int), which would truncate the (0, 1] value and skew the benchmark.
    auto work = [](double x)
    {
        return std::sqrt(std::abs(std::cos(std::sin(std::sqrt(x)))));
    };

    const size_t count = 1000;
    const int repeats = 1000000;

    //serial execution
    std::vector<double> valueSerial(count);
    auto startSerial = Clock::now();
    for (size_t i = 0; i < count; i++)
        for (int j = 0; j < repeats; j++)
            valueSerial[i] += work(static_cast<double>(i));
    auto durationSerial = elapsedMicroseconds(startSerial);
    std::cout << durationSerial << " Serial" << std::endl;

    //PPL parallel for
    std::vector<double> valueParallelForPPL(count);
    auto startParallelForPPL = Clock::now();
    Concurrency::parallel_for(size_t(0), count, [&](size_t i)
    {
        for (int j = 0; j < repeats; j++)
            valueParallelForPPL[i] += work(static_cast<double>(i));
    });
    auto durationParallelForPPL = elapsedMicroseconds(startParallelForPPL);
    std::cout << durationParallelForPPL << " PPL parallel for" << std::endl;

    //my parallel for
    std::vector<double> valueParallelFor(count);
    auto startParallelFor = Clock::now();
    parallel_forMine(0, count, [&](size_t i)
    {
        for (int j = 0; j < repeats; j++)
            valueParallelFor[i] += work(static_cast<double>(i));
    });
    auto durationParallelFor = elapsedMicroseconds(startParallelFor);
    std::cout << durationParallelFor << " My parallel for" << std::endl;

    // Consumes the results so the compiler cannot drop the loops. Exact equality is
    // expected because every variant evaluates the same expression for each index.
    for (size_t i = 0; i < valueSerial.size(); i++)
        if (valueSerial[i] != valueParallelFor[i] || valueParallelFor[i] != valueParallelForPPL[i])
            std::cout << "error";

    std::cin.get();

    return 0;
}

0 个答案:

没有答案