优化头痛 - 从查找表中删除if

时间:2011-09-12 13:34:54

标签: c++ arrays lookup

我正在尝试优化以下代码,这是我的应用程序中的瓶颈。 它的作用:它取双值value1和value2,并试图找到最大值,包括一个校正因子。如果两个值之间的差值大于5.0(LUT按比例10缩放),我可以取这两个值的最大值。如果差值小于5.0,我可以使用LUT中的校正因子。

有没有人知道这段代码有什么更好的风格?我不知道我在哪里浪费时间 - 是大量的ifs还是乘以10?

double value1, value2;
// Lookup Table scaled by 10 for (ln(1+exp(-abs(x)))), which is almost 0 for x > 5 and symmetrical around 0. LUT[0] is x=0.0, LUT[40] is x=4.0.
const logValue LUT[50] = { ... }

if (value1 > value2)
{
    if (value1 - value2 >= 5.0)
    {
        return value1;
    }
    else
    {
        return value1 + LUT[(uint8)((value1 - value2) * 10)];
    }
}
else
{
    if (value2 - value1 >= 5.0)
    {
        return value2;
    }
    else
    {
        return value2 + LUT[(uint8)((value2 - value1) * 10)];
    }
}

7 个答案:

答案 0 :(得分:2)

它可能同样落在两条路径上,导致处理器出现很多管道问题。

您是否尝试过分析?

我还建议尝试使用标准库,看看是否有帮助(例如,如果它能够使用和特定于处理器的指令):

double diff = std::fabs(value1 - value2);
double maxv = std::max(value1, value2);
return (diff >= 5.0) ? maxv : maxv + LUT[(uint8)((diff) * 10)];

答案 1 :(得分:2)

我可能已经编写了有点不同的代码来处理value2<value1情况:

if (value2 < value1) std::swap(value1, value2);
assert(value1 <= value2); // Assertion corrected
int diff = int((value2 - value1) * 10.0);
if (diff >= 50) diff = 49; // Integer comparison iso floating point
return value2 + LUT[diff];

答案 2 :(得分:2)

使用Excel进行几分钟的操作会产生一个近似等式,看起来它可能具有您所需的准确度,因此您可以完全取消查找表。您仍然需要一个条件来确保等式的参数保持在优化的范围内。

double diff = abs(value1 - value2);
double dmax = (value1 + value2 + diff) * 0.5; // same as (min+max+(max-min))/2
if (diff > 5.0)
    return dmax;
return dmax + 4.473865638/(2.611112371+diff) + 0.088190879*diff + -1.015046114;

P.S。我不保证这更快,只是它是一个不同的方法值得基准测试。

P.P.S。可以改变约束来得出稍微不同的常数,有很多变化。这是我做的另一套,你的表和公式之间的差异总是小于0.008,每个值也会小于前一个。

return dmax + 3.441318133/(2.296924445+diff) + 0.065529678*diff + -0.797081529;

修改:我测试了此代码(第二个公式),对100个0到10之间的一百万个随机数进行了100次传递,以及问题的原始代码MSalters currently accepted answer和蛮力实施max(value1,value2)+log(1.0+exp(-abs(value1-value2)))。我试用了双核AMD Athlon和Intel四核i7,结果大致一致。这是一个典型的运行:

  • 原文:1.32秒。
  • MSalters:1.13秒。
  • 我的:0.67秒。
  • 蛮力:4.50秒。
多年来,处理器的速度令人难以置信,现在它们可以快速地进行几次浮点乘法和除法,而不是在内存中查找值。这种方法不仅在现代x86上更快,而且更准确;方程中的近似误差远小于截断查找输入所引起的步进误差。

根据您的处理器和编译器,结果自然会有所不同;您的特定目标仍需要基准测试。

答案 3 :(得分:1)

我将假设在调用函数时,您很可能会得到必须使用查找表的部分,而不是>=5.0部分。在这种情况下,最好指导编译器。

double maxval = value1;
double difference_scaled = (value1-value2)*10;
if (difference < 0)
{
    difference = -difference;
    maxval = value2;
}
if (difference < 50)
    return maxval+LUT[(int)difference_scaled];
else
    return maxval;

试试这个并告诉我这是否会改善您的计划效果。

答案 4 :(得分:0)

此代码成为应用程序瓶颈的唯一原因是因为您多次调用它。你确定需要吗?也许代码中较高的算法可以更改为使用较少的比较?

答案 5 :(得分:0)

您在函数中计算了value1 - value2几次。只做一次。

施放到uint8_t也可能存在问题。就性能而言,用作从double到整数的转换的最佳整数类型是int,因为使用数组索引的最佳整数类型是int

max_value = value1;
diff = value1 - value2;
if (diff < 0.0) {
  max_value = value2;
  diff = -diff;
}

if (diff >= 5.0) {
  return max_value;
}
else {
  return max_value + LUT[(int)(diff * 10.0)];
}

请注意,上述内容可确保LUT索引介于0(含)和50(不包括)之间。这里不需要uint8_t

修改
经过一些变化后,这是一个相当快速的基于LUT的近似值log(exp(value1)+exp(value2))

#include <stdint.h>

// intptr_t *happens* to be fastest on my machine. YMMV.
typedef intptr_t IndexType;

double log_sum_exp (double value1, double value2, double *LUT) {
  double diff = value1 - value2;
  if (diff < 0.0) {
    value1 = value2;
    diff = -diff;
  }   
  IndexType idx = diff * 10.0;
  if (idx < 50) {
    value1 += LUT[idx];
  }   
  return value1;
}   

整数类型IndexType是加快速度的关键之一。我用clang和g ++进行了测试,两者都表明在我的计算机上转换为intptr_tlong)并使用intptr_t作为LUT的索引比其他整数类型更快。它比某些类型快得多。例如,unsigned long longuint8_t在我的计算机上的选择非常糟糕

该类型不仅仅是一个提示,至少对我使用的编译器而言。无论优化级别如何,这些编译器都完全按照代码告诉它做的关于从浮点类型到整数类型的转换。

通过将积分类型与50进行比较而不是将浮点类型与5.0进行比较,可以产生另一个减速带。

最后一个减速点:并非所有编译器都是相同的。 在我的计算机(YMMV)上,g++ -O3生成相当慢的代码(这个问题慢了25%!)而不是clang -O3,后者反过来生成的代码有点慢而不是clang -O4生成的。

我也使用了理性函数逼近法(类似于Mark Ransom的答案),但上面显然没有使用这种方法。

答案 6 :(得分:0)

我已经做了一些非常快速的测试,但请自行分析代码以验证效果。

LUT[]更改为静态变量让我加速了600%(从3.5秒增加到0.6秒)。这接近我使用的测试的绝对最小值(0.4s)。查看是否有效并重新配置以确定是否需要进一步优化。

作为参考,我只是简单地计算VC ++ 2010中这个循环的执行时间(内循环的1亿次迭代):

int Counter = 0;

for (double j = 0; j < 10; j += 0.001)
    {
        for (double i = 0; i < 10; i += 0.001)
        {
            ++Counter;
            Value1 += TestFunc1(i, j);
        }
    }