如何简洁地用多个参数初始化模板函数?

时间:2017-08-11 03:20:14

标签: c++ templates

我有一个包含许多模板参数的模板函数,它看起来像这样:

template <int a, int b, int c>
void foo(){
    for(int i = 0 ; i < a ; i ++)
       for(int j = 0 ; j < b ; j ++)
           for(int k = 0 ; k < c ; k ++)
               //do something
}

在这种情况下,使用模板参数(a b c),以便编译器可以展开这些循环。但实际上,这些参数可能需要很多价值。比如说,如果a / b / c中的每一个都可以从[1,2,3,4]中获取一个值,那么你总共需要64个模板函数进行初始化。所以,你的代码就像:

if(a == 1 && b == 1 && c == 1) foo<1,1,1>();
else if(a == 1 && b == 1 &&c == 2) foo<1,1,2>();
//......
else foo<4,4,4>();

这是一个很可怕的编码。那么,你有什么简洁的方法吗?

3 个答案:

答案 0 :(得分:1)

通过使用嵌套模板函数来确定每个模板参数,您可以消除大量重复,一次一个。这是一种破解,如果范围相当大,那么您可能希望使用代码生成来创建这些包装器。它不是非常理想的,但是用四个案例编写三个函数肯定比写出4个 3 条件更好。

void foo0(int a, int b, int c)
{
    switch (a) {
        case 1: foo1<1>(b, c); break;
        case 2: foo1<2>(b, c); break;
        case 3: foo1<3>(b, c); break;
        case 4: foo1<4>(b, c); break;
    }
}

template <int a>
void foo1(int b, int c)
{
    switch (b) {
        case 1: foo2<a, 1>(c); break;
        case 2: foo2<a, 2>(c); break;
        case 3: foo2<a, 3>(c); break;
        case 4: foo2<a, 4>(c); break;
    }
}

template <int a, int b>
void foo2(int c)
{
    switch (c) {
        case 1: foo<a, b, 1>(); break;
        case 2: foo<a, b, 2>(); break;
        case 3: foo<a, b, 3>(); break;
        case 4: foo<a, b, 4>(); break;
    }
}

您可以将整个任务委派给编译器,但代码明显长于这一个代码段,因此如果范围大于[1,5],则应该只跟踪该路径。以下是此方法的an example。它的冗长和kludgy(并且可能由更聪明的人简化)但它将完全在编译时生成决策树。请注意,树不是最佳的,但在您的情况下可能会或可能不重要。

#include <iostream>

template <int a, int b, int c>
void foo()
{
    std::cout << "foo<" << a << "," << b << "," << c << ">()\n";
}

// Class wrapper so that we can generically apply this function.
template <int a, int b, int c>
struct foo_wrapper
{
    void operator()() { foo<a, b, c>(); }
};

template <template <int, int, int> class fn, int min, int max, int a, int b, int i>
struct caller_p3
{
    static void call(int c) {
        if (c == i) {
            fn<a, b, i>()();
        } else {
            caller_p3<fn, min, max, a, b, i + 1>::call(c);
        }
    }
};

template <template <int, int, int> class fn, int min, int max, int a, int b>
struct caller_p3<fn, min, max, a, b, max>
{
    static void call(int c) {
        if (c == max) {
            fn<a, b, max>()();
        } else {
            // out of range, throw?
        }
    }
};

template <template <int, int, int> class fn, int min, int max, int a, int i>
struct caller_p2
{
    static void call(int b, int c) {
        if (b == i) {
            caller_p3<fn, min, max, a, i, min>::call(c);
        } else {
            caller_p2<fn, min, max, a, i + 1>::call(b, c);
        }
    }
};

template <template <int, int, int> class fn, int min, int max, int a>
struct caller_p2<fn, min, max, a, max>
{
    static void call(int b, int c) {
        if (b == max) {
            caller_p3<fn, min, max, a, max, min>::call(c);
        } else {
            // out of range, throw?
        }
    }
};

template <template <int, int, int> class fn, int min, int max, int i>
struct caller_p1
{
    static void call(int a, int b, int c) {
        if (a == i) {
            caller_p2<fn, min, max, i, min>::call(b, c);
        } else {
            caller_p1<fn, min, max, i + 1>::call(a, b, c);
        }
    }
};

template <template <int, int, int> class fn, int min, int max>
struct caller_p1<fn, min, max, max>
{
    static void call(int a, int b, int c) {
        if (a == max) {
            caller_p2<fn, min, max, max, min>::call(b, c);
        } else {
            // out of range, throw?
        }
    }
};

// Generic caller.
template <template <int, int, int> class fn, int min, int max>
struct caller
{
    void operator()(int a, int b, int c) {
        caller_p1<fn, min, max, min>::call(a, b, c);
    }
};

int main() {
    caller<foo_wrapper, 0, 5>()(1, 2, 3);
    caller<foo_wrapper, 0, 5>()(0, 0, 5);
    caller<foo_wrapper, 0, 5>()(5, 1, 0);
}

Here is a C++11 implementation使用可变参数模板允许foo()的任意数量的模板参数:

#include <iostream>

template <int a, int b, int c>
void foo()
{
    std::cout << "foo<" << a << "," << b << "," << c << ">()\n";
}

// Class wrapper so that we can generically apply this function.
template <int a, int b, int c>
struct foo_wrapper
{
    void operator()() { foo<a, b, c>(); }
};

// Caller implementation.
template <typename T, template <T...> class fn, T min, T max, T i, T... parms>
struct caller_impl
{
    template <typename... Tail>
    static void call(T head, Tail... tail)
    {
        if (head == i) {
            caller_impl<T, fn, min, max, min, parms..., i>::call(tail...);
        } else {
            caller_impl<T, fn, min, max, i + 1, parms...>::call(head, tail...);
        }
    }

    static void call()
    {
        fn<parms...>()();
    }
};

// Specialization for i==max
template <typename T, template <T...> class fn, T min, T max, T... parms>
struct caller_impl<T, fn, min, max, max, parms...>
{
    template <typename... Tail>
    static void call(T head, Tail... tail)
    {
        if (head == max) {
            caller_impl<T, fn, min, max, min, parms..., max>::call(tail...);
        } else {
            // Out of range, throw?
        }
    }

    static void call()
    {
        fn<parms...>()();
    }
};

// Helper to kick off the call.
template <typename T, template <T...> class fn, T min, T max, typename... parms>
void caller(parms... p)
{
    caller_impl<T, fn, min, max, min>::call(p...);
}

int main() {
    caller<int, foo_wrapper, 0, 5>(1, 2, 3);
    caller<int, foo_wrapper, 0, 5>(0, 0, 5);
    caller<int, foo_wrapper, 0, 5>(5, 1, 0);
}

答案 1 :(得分:1)

您可以创建一系列函数:

template <int a, int b, int c> void foo() { 
    // Your implementation
}

// Helper function which does the dispatch
template <std::size_t ... Is>
void foo(int index, std::index_sequence<Is...>)
{
    using f_type = void();
    f_type* f[] = {&foo<1 + Is / 16, 1 + (Is / 4) % 4, 1 + Is % 4>...};

    f[index]();
}

// function which call f<a, b, c>()
// a, b, c should be in [1;4]
void foo(int a, int b, int c)
{
    foo((a - 1) * 16 + (b - 1) * 4 + c - 1, std::make_index_sequence<64>());   
}

Demo

答案 2 :(得分:0)

首先,我必须问为什么你需要在某些情况下显式实例化你的函数 - 通常编译器可以自己解决这个问题。例如,这只是起作用:

template<int a, int b, int c>
void foo()
{
...
}

int main(const int argc, const char* argv[]) {
  foo<2,4,1>();
  ...
}

也许你需要在编译单元中使用foo()而不是定义它的那个。在这种情况下,您需要为计划使用的每个案例进行显式实例化。对于(1,2,3,4)中的[a,b,c],这就是诀窍:

#define _FWD_DECLc(a,b,c) foo<a, b, c>();
#define _FWD_DECLb(a,b)   _FWD_DECLc(a,b,1)  _FWD_DECLc(a,b,2) _FWD_DECLc(a,b,3) _FWD_DECLc(a,b,4)
#define _FWD_DECLa(a)     _FWD_DECLb(a,1)    _FWD_DECLb(a,2)   _FWD_DECLb(a,3)   _FWD_DECLb(a,4)
#define _FWD_DECL         _FWD_DECLa(1)      _FWD_DECLa(2)     _FWD_DECLa(3)     _FWD_DECLa(4)

void blah() {
  _FWD_DECL
}

这个想法是永远不会调用blah()但是为了编译它,编译器必须为([1,2,3,4]的所有组合实例化foo(), [1,2,3,4],[1,2,3,4])。当然,对于三个以上的模板参数,如果[a,b,c]远远高于4,这将会变得难以处理,但对于这种情况,它非常简洁。