优化求和代码

时间:2017-04-14 03:03:31

标签: c++ c performance optimization

来自采访:

int fn(int a, int b)
{
    int sum = 0; 
    for (int i = a * 4; i > 0; i--)
    {
        sum += b * i * i;
    } 
    return sum;
}

如何进一步优化此代码?我知道有一个总和公式,但我不认为记住这样的公式是面试官想要的。那么,你会如何优化它?

编辑:感谢chqrlie,faivvy,asimes和Ap31的建议和答案。所以我想现在有三种方法可以优化它:

  1. 我们可以在返回时执行此操作,而不是在每次迭代中乘以b。
  2. 用公式替换for循环:n *(n + 1)*(2 * n + 1)/ 6 * b。 A simple derivation
  3. 使用循环展开。请参阅asimes的帖子。
  4. 在这三个答案中,我可能会选择1和3,因为它们可以应用于具有相似结构的所有类型的代码。你应该提到有一个公式可以作为奖金,但我怀疑公式是否是采访者想要的。

    还有其他建议吗?

4 个答案:

答案 0 :(得分:3)

公式:1 * 1 + 2 * 2 + ... + n * n = n(n + 1)(2n + 1)/ 6

int fn(int a, int b)
{
    a <<= 2;
    return (a*(a + 1)*((a << 1) + 1) / 6) * b;
}

这就是你想要的吗?

答案 1 :(得分:3)

函数fn计算b乘以4*a之前的平方和,除非a为负数。

1n的平方和可以计算为 n(n + 1)(2n + 1)/ 6

这是一个C翻译:

int fn(int a, int b) {
    if (a <= 0 || b == 0) {
        return 0;
    } else {
        int n = a * 4;
        return n * (n + 1) * (2 * n + 1) / 6 * b;
    }
}

正如Ap31所指出的,clang足够精明以检测循环优化并将原始函数转换为直接计算,但它将上述代码编译为much more compact 16 assembly instructions(原始代码为36)。 / p>

为了避免中间结果可能出现溢出,这里有一个稍微不同的公式,不会计算更大的中间结果:

int fn(int a, int b) {
    if (a <= 0 || b == 0) {
        return 0;
    } else {
        if (a % 3 == 0)
            return (a / 3) * (4 * a + 1) * (8 * a + 1) * b * 2;
        else
            return (4 * a + 1) * (8 * a + 1) / 3 * a * b * 2;
    }
}

如果类型long long大于int,则更简单的替代方法是:

int fn(int a, int b) {
    if (a <= 0 || b == 0) {
        return 0;
    } else {
        unsigned long long n = a * 4;
        return (int)(n * (n + 1) * (2 * n + 1) / 6 * b);
    }
}

答案 2 :(得分:2)

面试官当然希望从@faivvy(和@chqrlie)答案中进行优化,你总是可以得出公式,或者只是说你知道它存在,你可以完全摆脱循环。

不要忘记一些常见的pifall:a可能是否定的,a*a*(2*a + 1)可能会溢出。

需要注意的另一件事是modern compilers can do this by themselves - 你也可以向面试官提及。

答案 3 :(得分:1)

正如@faivvy在他的回答中指出的那样,你可以尝试完全取消for循环

然而,另一种方法(正确处理否定a)是执行循环展开,我将调用该函数fnUnroll。如果您不熟悉循环展开,那么我们的想法是减少迭代次数并将值并行加总

正如评论中所提到的,每次迭代都不需要乘以b,这可以在最后完成。我添加了另一个名为fnUnrollNoMult的函数来显示此

#include <chrono>
#include <cstdlib>
#include <iostream>

int fn(int a, int b) {
    int sum = 0;
    for (int i = a * 4; i > 0; i--)
        sum += b * i * i;
    return sum;
}

int fnUnroll(int a, int b) {
    // Set up some number of accumulators, I picked 4
    int sum0 = 0;
    int sum1 = 0;
    int sum2 = 0;
    int sum3 = 0;

    int i = 1;
    int limit = a * 4;

    // Sum 4 values in parallel
    for ( ; i < limit; i += 4) {
        sum0 += b * i * i;
        sum1 += b * (i + 1) * (i + 1);
        sum2 += b * (i + 2) * (i + 2);
        sum3 += b * (i + 3) * (i + 3);
    }

    // Handle the remainder (if any)
    for ( ; i < limit; i++)
        sum0 += b * i + i;

    // Sum the accumulators
    return sum0 + sum1 + sum2 + sum3;
}

int fnUnrollNoMult(int a, int b) {
    int sum0 = 0;
    int sum1 = 0;
    int sum2 = 0;
    int sum3 = 0;

    // Remove b from the loops
    int i = 1;
    int limit = a * 4;
    for ( ; i < limit; i += 4) {
        sum0 += i * i;
        sum1 += (i + 1) * (i + 1);
        sum2 += (i + 2) * (i + 2);
        sum3 += (i + 3) * (i + 3);
    }
    for ( ; i < limit; i++)
        sum0 += i + i;

    // Handle b here
    return b * (sum0 + sum1 + sum2 + sum3);
}

int main(int argc, char** argv) {
    // Expects two arguments: a and b
    if (argc != 3) {
        std::cout << "Usage: " << argv[0] << " <int> <int>\n";
        return 1;
    }

    int a = atoi(argv[1]);
    int b = atoi(argv[2]);

    // This is just to demonstrate correctness
    for (int i = 0; i < 100; i++)
        for (int j = 0; j < 100; j++)
            if (
                fn(i, j) != fnUnroll(i, j) ||
                fn(i, j) != fnUnrollNoMult(i, j)
            ) {
                std::cout << "Not equal: " << i << ", " << j << std::endl;
                return 1;
            }

    // Benchmark
    using namespace std::chrono;
    {
        auto start = high_resolution_clock::now();
        int result = fn(a, b);
        auto stop  = high_resolution_clock::now();
        std::cout << "fn value:             " << result << std::endl;
        std::cout << "fn nanos:             " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
    }
    {
        auto start = high_resolution_clock::now();
        int result = fnUnroll(a, b);
        auto stop  = high_resolution_clock::now();
        std::cout << "fnUnroll value:       " << result << std::endl;
        std::cout << "fnUnroll nanos:       " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
    }
    {
        auto start = high_resolution_clock::now();
        int result = fnUnrollNoMult(a, b);
        auto stop  = high_resolution_clock::now();
        std::cout << "fnUnrollNoMult value: " << result << std::endl;
        std::cout << "fnUnrollNoMult nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
    }

    return 0;
}

下面的程序需要两个代表ab的参数。下面我将程序编译为g++ -std=c++14 foo.cpp -O3,并将这些结果用于某些a值:

./a.out 1 2
fn value:             60
fn nanos:             373
fnUnroll value:       60
fnUnroll nanos:       209
fnUnrollNoMult value: 60
fnUnrollNoMult nanos: 157
./a.out 1000 2
fn value:             -267004960
fn nanos:             3509
fnUnroll value:       -267004960
fnUnroll nanos:       2820
fnUnrollNoMult value: -267004960
fnUnrollNoMult nanos: 1568
./a.out 1000000 2
fn value:             -619707648
fn nanos:             3137685
fnUnroll value:       -619707648
fnUnroll nanos:       2387840
fnUnrollNoMult value: -619707648
fnUnrollNoMult nanos: 1220519