C ++ 11中张量的递归索引

时间:2018-07-27 18:16:05

标签: c++11 recursion indexing

我有一个等级N的张量类,它们包装存储在数组中的数据。例如,等级3张量将具有维度(d0,d1,d2),并且将使用长度为d0 * d1 * d2的基础数组中的多索引(i0,i1,i2)访问唯一元素。如果d0 = d1 = d2 = 10,i0 = 1,i1 = 2,i2 = 3,则将访问数组的元素123。

我实现了一个递归定义的类,该类从多索引计算单个数组索引,如下所示:

template<size_t N>
class TensorIndex : TensorIndex<N-1> {
private:
  size_t d;
public:
template<typename...Ds>
TensorIndex( size_t d0, Ds...ds ) : TensorIndex<N-1>( ds... ), d(d0) {}
  template<typename...Is>
  size_t index( size_t i0, Is...is ) {
    return i0+d*TensorIndex<N-1>::index(is...);
  }
};

template<>
struct TensorIndex<1> {
TensorIndex( size_t ) {}
  size_t index( size_t i ) { return i; }
};

哪个顺序颠倒了。

TensorIndex<3> g(10,10,10);
std::cout << g.index(1,2,3) << std::endl;

输出321.颠倒构造函数和索引函数的参数顺序的简单方法是什么?

编辑: 我尝试使用建议的方法来实现可变参数的反转,但这不是最佳选择,因为它需要同时反转索引和构造函数的参数,并且这两种情况下必要的辅助函数会略有不同。初始化程序列表答案看起来更简单。

1 个答案:

答案 0 :(得分:1)

无需递归或反转,可以使用initializer-list调用评估函数,该评估函数从左到右累积索引。初始化列表中调用的函数对象应具有非无效的返回类型:

#include <cstddef>
#include <iostream>

using namespace std;

template<size_t N>
class TensorIndex {
public:
    template<typename... Args>
    TensorIndex(Args... args) : dims{static_cast<size_t>(args)...}
    {
        static_assert(sizeof...(Args) == N,
                      "incorrect number of arguments for TensorIndex constructor");
    }

    template<typename... Args>
    size_t index(Args... args) {
        static_assert(sizeof...(Args) == N,
                      "incorrect number of arguments for TensorIndex::index()");
        IndexEval eval{dims};
        Pass pass{eval(args)...}; // evaluate from left to right : initializer-list                                                                           
        return eval.get_res();
    }

private:
    const size_t dims[N];

    class IndexEval {
        size_t k = 0;
        size_t res = 0;
        const size_t* dims;
    public:
        IndexEval(const size_t* dims) : dims{dims} {}
        size_t operator()(size_t i) {
            return res = res * dims[k++] + i;
        }
        size_t get_res() const { return res; }
    };

    struct Pass {
        template<typename... Args> Pass(Args...) {}
    };
};

int main()
{
    TensorIndex<3> g(10, 10, 10);
    cout << g.index(1, 2, 3) << endl;
}