简化布尔逻辑表达式以使其更快

时间:2014-02-02 10:36:10

标签: c++ algorithm optimization boolean boolean-expression

考虑以下两个功能:

template <typename Type, Type Mask, class = typename std::enable_if<std::is_unsigned<Type>::value>::type>
inline bool function1(const Type n, const Type m)
{
    const Type diff = m-n;
    const Type msk = Mask & diff;
    return (n <= m) && ((!msk && !diff) || (msk && msk <= diff));
}

template <typename Type, Type Mask, class = typename std::enable_if<std::is_unsigned<Type>::value>::type>
inline bool function2(const Type n, const Type m)
{
    return (n <= m) && ((!(Mask & (m-n)) && !(m-n)) || ((Mask & (m-n)) && (Mask & (m-n)) <= (m-n)));
}

他们做同样的事情,除了第一个因为使用临时值而更具可读性(function2function1但是我用它们的原始值替换了临时值。) / p>

恰好function2function1快一点,并且由于我将在超级计算机上称之为十亿次,我想知道是否有更简单的布尔表达式产生完全相同的结果(Type将始终是无符号整数类型。)

3 个答案:

答案 0 :(得分:3)

表达式可以如下优化:

  1. (!msk && !diff)可以重写为!diff,因为如果两者都为零,则表达式为真,如果msk为零,则diff为零。
  2. 此外,diff总是不>= msk?也就是说,因为使用&无法增加msk的值。 (如果Type是无符号整数,则成立)
  3. 我改变了!diffmsk的顺序,因为msk!diff更常见,似乎有道理。
  4. 最后的表达是:

    (n <= m) && (msk || !diff)

    另一个等效表达式(由anatolyg建议)是:

    (n < m && (Mask && (m - n))) || (n == m)
    

答案 1 :(得分:2)

测试可能存在缺陷。

首次测试

#include <iostream>
#include <chrono>

template <unsigned Mask>
inline bool function1(const unsigned n, const unsigned m)
{
    const unsigned diff = m-n;
    const unsigned msk = Mask & diff;
    return (n <= m) && ((!msk && !diff) || (msk && msk <= diff));
}

template <unsigned Mask>
inline bool function2(const unsigned n, const unsigned m)
{
    return (n <= m) && ((!(Mask & (m-n)) && !(m-n)) || ((Mask & (m-n)) && (Mask & (m-n)) <= (m-n)));
}

template <unsigned Mask>
inline bool function3(const unsigned n, const unsigned m)
{
    if(m < n) return false;
    else if(m == n) return true;
    else return Mask & (m-n);
}

template <unsigned Mask>
inline bool function4(const unsigned n, const unsigned m)
{
    return (n < m && (Mask & (m-n))) || (n == m);
}

volatile unsigned a = std::rand();
volatile unsigned b = std::rand();
volatile bool result;

inline double duration(
    std::chrono::system_clock::time_point start,
    std::chrono::system_clock::time_point end)
{
    return double((end - start).count())
        / std::chrono::system_clock::period::den;
}

int main() {
    typedef bool (*Function)(const unsigned, const unsigned);
    const unsigned N = 4;
    std::chrono::system_clock::duration timing[N] = {};
    Function fn[] = {
        function1<0x1234>,
        function2<0x1234>,
        function3<0x1234>,
        function4<0x1234>,
    };
    for(unsigned i = 0; i < 10000; ++i) {
        for(unsigned j = 0; j < 100; ++j) {
            unsigned Loops = 100;
            for(unsigned f = 0; f < N; ++f) {
                auto start = std::chrono::system_clock::now();
                for(unsigned loop = 0; loop < Loops; ++loop) {
                    result = fn[f](a, b);
                }
                auto end = std::chrono::system_clock::now();
                timing[f] += (end-start);
            }
        }
    }

    for(unsigned i = 0; i < 4; ++i) {
        std::cout
            << "Timing " << i+1 << ": "
            << double(timing[i].count()) / std::chrono::system_clock::period::den
            << "\n";
    }
}

使用g ++ -std = c ++ 11 -O3编译 所示:

Timing 1: 0.435909
Timing 2: 0.435438
Timing 3: 0.435435
Timing 4: 0.435523

第二次测试:

#include <iostream>
#include <chrono>

inline bool function1(const unsigned Mask, const unsigned n, const unsigned m)
{
    const unsigned diff = m-n;
    const unsigned msk = Mask & diff;
    return (n <= m) && ((!msk && !diff) || (msk && msk <= diff));
}

inline bool function2(const unsigned Mask, const unsigned n, const unsigned m)
{
    return (n <= m) && ((!(Mask & (m-n)) && !(m-n)) || ((Mask & (m-n)) && (Mask & (m-n)) <= (m-n)));
}

inline bool function3(const unsigned Mask, const unsigned n, const unsigned m)
{
    if(m < n) return false;
    else if(m == n) return true;
    else return Mask & (m-n);
}

inline bool function4(const unsigned Mask, const unsigned n, const unsigned m)
{
    return (n < m && (Mask & (m-n))) || (n == m);
}

inline double duration(
    std::chrono::system_clock::time_point start,
    std::chrono::system_clock::time_point end)
{
    return double((end - start).count())
        / std::chrono::system_clock::period::den;
}

int main() {
    typedef bool (*Function)(const unsigned, const unsigned, const unsigned);
    const unsigned N = 4;
    std::chrono::system_clock::duration timing[N] = {};
    Function fn[] = {
        function1,
        function2,
        function3,
        function4,
    };
    const unsigned OuterLoops = 1000000;
    const unsigned InnerLoops = 100;
    const unsigned Samples = OuterLoops * InnerLoops;
    unsigned* M = new unsigned[Samples];
    unsigned* A = new unsigned[Samples];
    unsigned* B = new unsigned[Samples];
    for(unsigned i = 0; i < Samples; ++i) {
        M[i] = std::rand();
        A[i] = std::rand();
        B[i] = std::rand();
    }
    unsigned result[N];
    for(unsigned i = 0; i < OuterLoops; ++i) {
        for(unsigned f = 0; f < N; ++f) {
            auto start = std::chrono::system_clock::now();
            for(unsigned j = 0; j < InnerLoops; ++j) {
                unsigned index = i + j;
                unsigned mask = M[index];
                unsigned a = A[index];
                unsigned b = B[index];
                result[f] = fn[f](mask, a, b);
            }
            auto end = std::chrono::system_clock::now();
            timing[f] += (end-start);
        }
        for(unsigned f = 1; f < N; ++f) {
            if(result[0] != result[f]) {
                std::cerr << "Different Results\n";
                exit(-1);
            }
        }
    }

    for(unsigned i = 0; i < 4; ++i) {
        std::cout
            << "Timing " << i+1 << ": "
            << double(timing[i].count()) / std::chrono::system_clock::period::den
            << "\n";
    }
}

使用g ++ -std = c ++ 11 -O3编译 所示:

Timing 1: 0.763875
Timing 2: 0.738105
Timing 3: 0.518714
Timing 4: 0.785299

反汇编第二个函数(没有内联编译):

0000000000000000 <_Z9function1jjj>:
   0:   31 c0                   xor    %eax,%eax
   2:   39 f2                   cmp    %esi,%edx
   4:   72 10                   jb     16 <_Z9function1jjj+0x16>
   6:   29 f2                   sub    %esi,%edx
   8:   21 d7                   and    %edx,%edi
   a:   89 f8                   mov    %edi,%eax
   c:   09 d0                   or     %edx,%eax
   e:   74 18                   je     28 <_Z9function1jjj+0x28>
  10:   39 d7                   cmp    %edx,%edi
  12:   76 0c                   jbe    20 <_Z9function1jjj+0x20>
  14:   31 c0                   xor    %eax,%eax
  16:   f3 c3                   repz retq 
  18:   0f 1f 84 00 00 00 00    nopl   0x0(%rax,%rax,1)
  1f:   00 
  20:   85 ff                   test   %edi,%edi
  22:   74 f0                   je     14 <_Z9function1jjj+0x14>
  24:   0f 1f 40 00             nopl   0x0(%rax)
  28:   b8 01 00 00 00          mov    $0x1,%eax
  2d:   c3                      retq   
  2e:   66 90                   xchg   %ax,%ax

0000000000000030 <_Z9function2jjj>:
  30:   31 c0                   xor    %eax,%eax
  32:   39 d6                   cmp    %edx,%esi
  34:   77 0c                   ja     42 <_Z9function2jjj+0x12>
  36:   89 d1                   mov    %edx,%ecx
  38:   29 f1                   sub    %esi,%ecx
  3a:   21 cf                   and    %ecx,%edi
  3c:   75 0a                   jne    48 <_Z9function2jjj+0x18>
  3e:   39 f2                   cmp    %esi,%edx
  40:   74 0a                   je     4c <_Z9function2jjj+0x1c>
  42:   f3 c3                   repz retq 
  44:   0f 1f 40 00             nopl   0x0(%rax)
  48:   39 f9                   cmp    %edi,%ecx
  4a:   72 f6                   jb     42 <_Z9function2jjj+0x12>
  4c:   b8 01 00 00 00          mov    $0x1,%eax
  51:   c3                      retq   
  52:   66 66 66 66 66 2e 0f    data32 data32 data32 data32 nopw %cs:0x0(%rax,%rax,1)
  59:   1f 84 00 00 00 00 00 

0000000000000060 <_Z9function3jjj>:
  60:   31 c0                   xor    %eax,%eax
  62:   39 f2                   cmp    %esi,%edx
  64:   72 0f                   jb     75 <_Z9function3jjj+0x15>
  66:   74 08                   je     70 <_Z9function3jjj+0x10>
  68:   29 f2                   sub    %esi,%edx
  6a:   85 fa                   test   %edi,%edx
  6c:   0f 95 c0                setne  %al
  6f:   c3                      retq   
  70:   b8 01 00 00 00          mov    $0x1,%eax
  75:   f3 c3                   repz retq 
  77:   66 0f 1f 84 00 00 00    nopw   0x0(%rax,%rax,1)
  7e:   00 00 

0000000000000080 <_Z9function4jjj>:
  80:   39 d6                   cmp    %edx,%esi
  82:   73 0d                   jae    91 <_Z9function4jjj+0x11>
  84:   89 d1                   mov    %edx,%ecx
  86:   b8 01 00 00 00          mov    $0x1,%eax
  8b:   29 f1                   sub    %esi,%ecx
  8d:   85 f9                   test   %edi,%ecx
  8f:   75 05                   jne    96 <_Z9function4jjj+0x16>
  91:   39 d6                   cmp    %edx,%esi
  93:   0f 94 c0                sete   %al
  96:   f3 c3                   repz retq 

<强>设备:

英特尔®酷睿™i3-2310M CPU @ 2.10GHz×4 7.7 GiB

我的结论:

  • 正确分析算法(参见@George答案)
  • 在简单代码中表达优化算法,并对编译器进行微调优化。
  • 编写正确的测试用例(测量),但测量类型会影响结果。 (这里:第一个和第二个显示不同的结果) -

答案 2 :(得分:1)

f1和f2之间的差异可能是因为在f1中n> m的情况下编译器无法延迟diff和msk的评估。

下面是一个示例代码,用于计算我的计算机在VS2013下的功能和结果的微秒,同样,正如@George所说,还有多余的评估,所以我添加了f1b和f3。

f1 = 98201.7us
f1b = 95574.1us
f2 = 96613.1us
f3 = 94809.9us

代码:

#include <iostream>
#include <vector>
#include <random>
#include <limits>
#include <chrono>
#include <algorithm>

#define NOMINMAX
#include <windows.h>

struct HighResClock {
    typedef long long                               rep;
    typedef std::nano                               period;
    typedef std::chrono::duration<rep, period>      duration;
    typedef std::chrono::time_point<HighResClock>   time_point;
    static const bool is_steady = true;

    static time_point now( );
};
namespace {
    const long long g_Frequency = [] ( ) -> long long {
        LARGE_INTEGER frequency;
        QueryPerformanceFrequency( &frequency );
        return frequency.QuadPart;
    }( );
}

HighResClock::time_point HighResClock::now( ) {
    LARGE_INTEGER count;
    QueryPerformanceCounter( &count );
    return time_point( duration( count.QuadPart * static_cast<rep>( period::den ) / g_Frequency ) );
}

template <typename Type, Type Mask>
inline bool function1( const Type n, const Type m ) {
    static_assert( std::is_unsigned<Type>::value, "Type must be unsigned" );
    const Type diff = m - n;
    const Type msk = Mask & diff;
    return ( n <= m ) && ( ( !msk && !diff ) || ( msk && msk <= diff ) );
}

template <typename Type, Type Mask>
inline bool function1b( const Type n, const Type m ) {
    static_assert( std::is_unsigned<Type>::value, "Type must be unsigned" );
    if ( n > m )
        return false;
    const Type diff = m - n;
    const Type msk = Mask & diff;
    return ( ( !msk && !diff ) || ( msk && msk <= diff ) );
}

template <typename Type, Type Mask>
inline bool function2( const Type n, const Type m ) {
    static_assert( std::is_unsigned<Type>::value, "Type must be unsigned" );
    return ( n <= m ) && ( ( !( Mask & ( m - n ) ) && !( m - n ) ) || ( ( Mask & ( m - n ) ) && ( Mask & ( m - n ) ) <= ( m - n ) ) );
}

template <typename Type, Type Mask>
inline bool function3( const Type n, const Type m ) {
    static_assert( std::is_unsigned<Type>::value, "Type must be unsigned" );
    if ( n == m )
        return true;
    if ( n>m )
        return false;
    const Type diff = m - n;
    const Type msk = Mask & diff;
    return msk && msk <= diff;
}

std::vector<std::pair<size_t, size_t>> fill( size_t n ) {
    std::random_device rd;
    std::mt19937 gen( rd( ) );
    std::uniform_int_distribution<size_t> dis( 0, std::numeric_limits<size_t>::max( ) );
    auto rnd = [ &] { return dis( gen ); };

    std::vector<std::pair<size_t, size_t>> result;
    result.reserve( n );
    while ( n-- ) {
        result.push_back( { rnd( ), rnd( ) } );
    }
    return result;
}

size_t ignoreOptim {};
template <typename F>
std::chrono::microseconds foo( std::vector<std::pair<size_t, size_t>>  const  nms, F &&f ) {
    using clock = HighResClock; // Does VS2014 will fix the high_resolution_clock fallbacking to system_clock ???

    auto t0 = clock::now( );
    auto f1 = std::count_if( begin( nms ), end( nms ), std::forward<F&&>( f ) );
    auto t1 = clock::now( );
    ignoreOptim += f1;

    auto result = std::chrono::duration_cast<std::chrono::microseconds>( t1 - t0 );
    return result;
}

template <typename F> 
void bar( std::vector<std::pair<size_t, size_t>>  const  nms, char const* name, F &&f ) {
    std::chrono::microseconds f1 {};
    for ( int i {}; i != 100; ++i )
        f1 += foo( nms, std::forward<F&&>( f ) );
    std::cout << name << " = " << float( f1.count( ) ) / 10.f << "us" << std::endl;
}
int main( ) {
    auto nms = fill( 1 << 21 );
    bar( nms, "f1", [] ( std::pair<size_t, size_t> nm ) { return function1<size_t, 0x0003000000000000ull>( nm.first, nm.second ); } );
    bar( nms, "f1b", [] ( std::pair<size_t, size_t> nm ) { return function1b<size_t, 0x0003000000000000ull>( nm.first, nm.second ); } );

    bar( nms, "f2", [] ( std::pair<size_t, size_t> nm ) { return function2<size_t, 0x0003000000000000ull>( nm.first, nm.second ); } );
    bar( nms, "f3", [] ( std::pair<size_t, size_t> nm ) { return function3<size_t, 0x0003000000000000ull>( nm.first, nm.second ); } );

    return 0;
}