根据另一个可变参数包查找可变参数包的收缩

时间:2016-05-17 12:58:07

标签: c++ c++11 multidimensional-array template-meta-programming

我正在研究静态多维数组收缩框架,我遇到了一个有点难以解释的问题,但我会尽我所能。假设我们有一个N维数组类

template<typename T, int ... dims>
class Array {}

可以实例化为

Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on

现在我们有另一个名为Indices

的类
template<int ... Idx>
struct Indices {}

我现在有一个函数contraction,其签名应如下所示

template<T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>> 
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a)

我可能没有在这里获得语法,但我基本上希望返回的Array具有基于Indices条目的维度。让我举一个contraction可以执行的示例。请注意,在此上下文中,收缩表示删除索引列表中的参数相等的维

auto arr = contraction(Indices<0,0>, Array<double,3,3>) 
// arr is Array<double> as both indices contract 0==0

auto arr = contraction(Indices<0,1>, Array<double,3,3>) 
// arr is Array<double,3,3> as no contraction happens here, 0!=1

auto arr = contraction(Indices<0,1,0>, Array<double,3,4,3>) 
// arr is Array<double,4> as 1st and 3rd indices contract 0==0  

auto arr = contraction(Indices<0,1,0,7,7,2>, Array<double,3,4,3,5,5,6>) 
// arr is Array<double,4,6> as (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract

auto arr = contraction(Indices<10,10,2,3>, Array<double,5,6,4,4>
// should not compile as contraction between 1st and 2nd arguments 
// requested but dimensions don't match 5!=6

// The parameters of Indices really do not matter as long as 
// we can identify contractions. They are typically expressed as enums, I,J,K...

基本上,鉴于Idx...Dims...都应该具有相同的大小,请检查Idx...中的哪些值相等,获取它们出现的位置并删除相应的条目(职位)Dims...。这基本上是tensor contraction rule

阵列收缩规则:

  1. 索引的参数数量和数组的维度/等级应该相同,即sizeof...(Idx)==sizeof...(Dims)
  2. IdxDims之间存在一对一对应关系,即如果我们有Indices<0,1,2>Array<double,4,5,6>,{{1} }映射到04映射到15映射到2
  3. 如果6中存在相同/相等的值,则表示收缩,这意味着Idx中的相应维度应该消失,例如,如果我们有Dims和{{1然后Indices<0,0,3>以及这些值映射到Array<double,4,4,6>0==0的相应维度都需要消失,结果数组应该是4
  4. 如果4具有相同的值,但相应的Array<double,6>不匹配,则应触发编译时错误,例如Idx和{{1} } Dims是不可能的,同样Indices<0,0,3>不可能作为Array<double,4,5,6>,这导致
  5. 对于具有不同尺寸的阵列,不可能收缩,例如4!=5无法以任何方式签约。
  6. 只要相应的Indices<0,1,0>也匹配,4!=6就允许多对,三元组,四元组等,例如Array<double,4,5,6>将与Idx签约,给定输入数组为Dims
  7. 我对metaprogamming的了解并没有达到这个目的,但是我希望我已经明确了这个意图,因为有人可以指导我朝着正确的方向前进。

2 个答案:

答案 0 :(得分:3)

进行实际检查的一堆constexpr函数:

// is ind[i] unique in ind?
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
    return cur == N ? true : 
           (cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}

// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
                            const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
    return cur == N ? true :
           (ind[cur] != index || dim[cur] == dimension) ? 
                check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}

// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
    return is_uniq(ind, i) ? dim[i] :
           check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}

现在我们需要一种摆脱-1 s:

的方法
template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
    :  concat<Indices<I1..., I2...>, Inds...> {};

// filter out all instances of I from Is...,
// return the rest as an Indices    
template<int I, int... Is>
struct filter
    :  concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};

使用它们:

template<class Ind, class Arr, class Seq>
struct contraction_impl;

template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
    static constexpr int ind[] = { Ind... };
    static constexpr int dim[] = { Dim... };
    static constexpr int result[] = {calc(Seq, ind, dim)...};

    template<int... Dims>
    static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;

    using type = decltype(unpack_helper(typename filter<-1,  result[Seq]...>::type{}));
};


template<class T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>, 
                          std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);

make_index_sequence之外的所有内容都是C ++ 11。你可以在SO上找到大量的实现。

答案 1 :(得分:1)

这是一团糟,但我认为它可以做你想做的事。几乎可以肯定有许多简化可以做到这一点,但这是我通过测试的第一次通过。请注意,这不会实现收缩,而只是确定类型应该是什么。如果那不是您所需要的,我会提前道歉。

#include <type_traits>

template <std::size_t...>
struct Indices {};

template <typename, std::size_t...>
struct Array {};

// Count number of 'i' in 'rest...', base case
template <std::size_t i, std::size_t... rest>
struct Count : std::integral_constant<std::size_t, 0>
{};

// Count number of 'i' in 'rest...', inductive case
template <std::size_t i, std::size_t j, std::size_t... rest>
struct Count<i, j, rest...> :
    std::integral_constant<std::size_t,
                           Count<i, rest...>::value + ((i == j) ? 1 : 0)>
{};

// Is 'i' contained in 'rest...'?
template <std::size_t i, std::size_t... rest>
struct Contains :
    std::integral_constant<bool, (Count<i, rest...>::value > 0)>
{};


// Accumulation of counts of indices in all, base case
template <typename All, typename Remainder,
          typename AccIdx, typename AccCount>
struct Counts {
    using indices = AccIdx;
    using counts = AccCount;
};

// Accumulation of counts of indices in all, inductive case
template <std::size_t... all, std::size_t i, std::size_t... rest,
          std::size_t... indices, std::size_t... counts>
struct Counts<Indices<all...>, Indices<i, rest...>,
              Indices<indices...>, Indices<counts...>>
    : std::conditional<Contains<i, indices...>::value,
                       Counts<Indices<all...>, Indices<rest...>,
                              Indices<indices...>,
                              Indices<counts...>>,
                       Counts<Indices<all...>, Indices<rest...>,
                              Indices<indices..., i>,
                              Indices<counts...,
                                      Count<i, all...>::value>>>::type
{};

// Get value in From that matched the first value of Idx that matched idx
template <std::size_t idx, typename Idx, typename From>
struct First : std::integral_constant<std::size_t, 0>
{};
template <std::size_t i, std::size_t j, std::size_t k,
          std::size_t... indices, std::size_t... values>
struct First<i, Indices<j, indices...>, Indices<k, values...>>
    : std::conditional<i == j,
                       std::integral_constant<std::size_t, k>,
                       First<i, Indices<indices...>,
                             Indices<values...>>>::type
{};

// Return whether all values in From that match Idx being idx are tgt
template <std::size_t idx, std::size_t tgt, typename Idx, typename From>
struct AllMatchTarget : std::true_type
{};
template <std::size_t idx, std::size_t tgt,
          std::size_t i, std::size_t j,
          std::size_t... indices, std::size_t... values>
struct AllMatchTarget<idx, tgt,
                      Indices<i, indices...>, Indices<j, values...>>
    : std::conditional<i == idx && j != tgt, std::false_type,
                       AllMatchTarget<idx, tgt, Indices<indices...>,
                                      Indices<values...>>>::type
{};

/* Generate the dimensions, given the counts, indices, and values */
template <typename Counts, typename Indices,
          typename AllIndices, typename Values, typename Accum>
struct GenDims;

template <typename A, typename V, typename R>
struct GenDims<Indices<>, Indices<>, A, V, R> {
    using type = R;
};
template <typename T, std::size_t i, std::size_t c,
          std::size_t... counts, std::size_t... indices,
          std::size_t... dims, typename AllIndices, typename Values>
struct GenDims<Indices<c, counts...>, Indices<i, indices...>,
               AllIndices, Values, Array<T, dims...>>
{
    static constexpr auto value = First<i, AllIndices, Values>::value;
    static_assert(AllMatchTarget<i, value, AllIndices, Values>::value,
                  "Index doesn't correspond to matching dimensions");
    using type = typename GenDims<
        Indices<counts...>, Indices<indices...>,
        AllIndices, Values,
        typename std::conditional<c == 1,
                                  Array<T, dims..., value>,
                                  Array<T, dims...>>::type>::type;
};

/* Put it all together */
template <typename I, typename A>
struct ContractionType;

template <typename T, std::size_t... indices, std::size_t... values>
struct ContractionType<Indices<indices...>, Array<T, values...>> {
    static_assert(sizeof...(indices) == sizeof...(values),
                   "Number of indices and dimensions do not match");
    using counts = Counts<Indices<indices...>,
                          Indices<indices...>,
                          Indices<>, Indices<>>;
    using type = typename GenDims<typename counts::counts,
                                  typename counts::indices,
                                  Indices<indices...>, Indices<values...>,
                                  Array<T>>::type;
};

static_assert(std::is_same<typename
              ContractionType<Indices<0, 0>, Array<double, 3, 3>>::type,
              Array<double>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1>, Array<double, 3, 3>>::type,
              Array<double, 3, 3>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1, 0>, Array<double, 3, 4, 3>>::type,
              Array<double, 4>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1, 0, 7, 7, 2>,
              Array<double, 3, 4, 3, 5, 5, 6>>::type,
              Array<double, 4, 6>>::value, "");

// Errors appropriately when uncommented
/* static_assert(std::is_same<typename */
/*               ContractionType<Indices<10,10, 2, 3>, */
/*               Array<double, 5,6,4,4>>::type, */
/*               Array<double>::value, ""); */

以下是对此处发生的事情的解释:

  • 首先,我使用Counts生成唯一索引列表(Counts::indices)以及每个索引在序列中显示的次数(Counts::counts)。
  • 然后我走过索引,计算来自Counts的对,并且对于每个索引,如果计数为1,我累积值并递归。否则,我传递累计值并递归。

最令人恼火的部分是static_assert中的GenDims,它会验证所有匹配尺寸相同的索引。