来自采访:
int fn(int a, int b)
{
int sum = 0;
for (int i = a * 4; i > 0; i--)
{
sum += b * i * i;
}
return sum;
}
如何进一步优化此代码?我知道有一个总和公式,但我不认为记住这样的公式是面试官想要的。那么,你会如何优化它?
编辑:感谢chqrlie,faivvy,asimes和Ap31的建议和答案。所以我想现在有三种方法可以优化它:
在这三个答案中,我可能会选择1和3,因为它们可以应用于具有相似结构的所有类型的代码。你应该提到有一个公式可以作为奖金,但我怀疑公式是否是采访者想要的。
还有其他建议吗?
答案 0 :(得分:3)
公式:1 * 1 + 2 * 2 + ... + n * n = n(n + 1)(2n + 1)/ 6
int fn(int a, int b)
{
a <<= 2;
return (a*(a + 1)*((a << 1) + 1) / 6) * b;
}
这就是你想要的吗?
答案 1 :(得分:3)
函数fn
计算b
乘以4*a
之前的平方和,除非a
为负数。
1
到n
的平方和可以计算为 n(n + 1)(2n + 1)/ 6 。
这是一个C翻译:
int fn(int a, int b) {
if (a <= 0 || b == 0) {
return 0;
} else {
int n = a * 4;
return n * (n + 1) * (2 * n + 1) / 6 * b;
}
}
正如Ap31所指出的,clang足够精明以检测循环优化并将原始函数转换为直接计算,但它将上述代码编译为much more compact 16 assembly instructions(原始代码为36)。 / p>
为了避免中间结果可能出现溢出,这里有一个稍微不同的公式,不会计算更大的中间结果:
int fn(int a, int b) {
if (a <= 0 || b == 0) {
return 0;
} else {
if (a % 3 == 0)
return (a / 3) * (4 * a + 1) * (8 * a + 1) * b * 2;
else
return (4 * a + 1) * (8 * a + 1) / 3 * a * b * 2;
}
}
如果类型long long
大于int
,则更简单的替代方法是:
int fn(int a, int b) {
if (a <= 0 || b == 0) {
return 0;
} else {
unsigned long long n = a * 4;
return (int)(n * (n + 1) * (2 * n + 1) / 6 * b);
}
}
答案 2 :(得分:2)
面试官当然希望从@faivvy(和@chqrlie)答案中进行优化,你总是可以得出公式,或者只是说你知道它存在,你可以完全摆脱循环。
不要忘记一些常见的pifall:a
可能是否定的,a*a*(2*a + 1)
可能会溢出。
需要注意的另一件事是modern compilers can do this by themselves - 你也可以向面试官提及。
答案 3 :(得分:1)
正如@faivvy在他的回答中指出的那样,你可以尝试完全取消for循环
然而,另一种方法(正确处理否定a
)是执行循环展开,我将调用该函数fnUnroll
。如果您不熟悉循环展开,那么我们的想法是减少迭代次数并将值并行加总
正如评论中所提到的,每次迭代都不需要乘以b
,这可以在最后完成。我添加了另一个名为fnUnrollNoMult
的函数来显示此
#include <chrono>
#include <cstdlib>
#include <iostream>
int fn(int a, int b) {
int sum = 0;
for (int i = a * 4; i > 0; i--)
sum += b * i * i;
return sum;
}
int fnUnroll(int a, int b) {
// Set up some number of accumulators, I picked 4
int sum0 = 0;
int sum1 = 0;
int sum2 = 0;
int sum3 = 0;
int i = 1;
int limit = a * 4;
// Sum 4 values in parallel
for ( ; i < limit; i += 4) {
sum0 += b * i * i;
sum1 += b * (i + 1) * (i + 1);
sum2 += b * (i + 2) * (i + 2);
sum3 += b * (i + 3) * (i + 3);
}
// Handle the remainder (if any)
for ( ; i < limit; i++)
sum0 += b * i + i;
// Sum the accumulators
return sum0 + sum1 + sum2 + sum3;
}
int fnUnrollNoMult(int a, int b) {
int sum0 = 0;
int sum1 = 0;
int sum2 = 0;
int sum3 = 0;
// Remove b from the loops
int i = 1;
int limit = a * 4;
for ( ; i < limit; i += 4) {
sum0 += i * i;
sum1 += (i + 1) * (i + 1);
sum2 += (i + 2) * (i + 2);
sum3 += (i + 3) * (i + 3);
}
for ( ; i < limit; i++)
sum0 += i + i;
// Handle b here
return b * (sum0 + sum1 + sum2 + sum3);
}
int main(int argc, char** argv) {
// Expects two arguments: a and b
if (argc != 3) {
std::cout << "Usage: " << argv[0] << " <int> <int>\n";
return 1;
}
int a = atoi(argv[1]);
int b = atoi(argv[2]);
// This is just to demonstrate correctness
for (int i = 0; i < 100; i++)
for (int j = 0; j < 100; j++)
if (
fn(i, j) != fnUnroll(i, j) ||
fn(i, j) != fnUnrollNoMult(i, j)
) {
std::cout << "Not equal: " << i << ", " << j << std::endl;
return 1;
}
// Benchmark
using namespace std::chrono;
{
auto start = high_resolution_clock::now();
int result = fn(a, b);
auto stop = high_resolution_clock::now();
std::cout << "fn value: " << result << std::endl;
std::cout << "fn nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
}
{
auto start = high_resolution_clock::now();
int result = fnUnroll(a, b);
auto stop = high_resolution_clock::now();
std::cout << "fnUnroll value: " << result << std::endl;
std::cout << "fnUnroll nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
}
{
auto start = high_resolution_clock::now();
int result = fnUnrollNoMult(a, b);
auto stop = high_resolution_clock::now();
std::cout << "fnUnrollNoMult value: " << result << std::endl;
std::cout << "fnUnrollNoMult nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
}
return 0;
}
下面的程序需要两个代表a
和b
的参数。下面我将程序编译为g++ -std=c++14 foo.cpp -O3
,并将这些结果用于某些a
值:
./a.out 1 2
fn value: 60
fn nanos: 373
fnUnroll value: 60
fnUnroll nanos: 209
fnUnrollNoMult value: 60
fnUnrollNoMult nanos: 157
./a.out 1000 2
fn value: -267004960
fn nanos: 3509
fnUnroll value: -267004960
fnUnroll nanos: 2820
fnUnrollNoMult value: -267004960
fnUnrollNoMult nanos: 1568
./a.out 1000000 2
fn value: -619707648
fn nanos: 3137685
fnUnroll value: -619707648
fnUnroll nanos: 2387840
fnUnrollNoMult value: -619707648
fnUnrollNoMult nanos: 1220519