我正在研究静态多维数组收缩框架,我遇到了一个有点难以解释的问题,但我会尽我所能。假设我们有一个N
维数组类
template<typename T, int ... dims>
class Array {}
可以实例化为
Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on
现在我们有另一个名为Indices
template<int ... Idx>
struct Indices {}
我现在有一个函数contraction
,其签名应如下所示
template<T, int ... Dims, int ... Idx,
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>>
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a)
我可能没有在这里获得语法,但我基本上希望返回的Array
具有基于Indices
条目的维度。让我举一个contraction
可以执行的示例。请注意,在此上下文中,收缩表示删除索引列表中的参数相等的维。
auto arr = contraction(Indices<0,0>, Array<double,3,3>)
// arr is Array<double> as both indices contract 0==0
auto arr = contraction(Indices<0,1>, Array<double,3,3>)
// arr is Array<double,3,3> as no contraction happens here, 0!=1
auto arr = contraction(Indices<0,1,0>, Array<double,3,4,3>)
// arr is Array<double,4> as 1st and 3rd indices contract 0==0
auto arr = contraction(Indices<0,1,0,7,7,2>, Array<double,3,4,3,5,5,6>)
// arr is Array<double,4,6> as (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract
auto arr = contraction(Indices<10,10,2,3>, Array<double,5,6,4,4>
// should not compile as contraction between 1st and 2nd arguments
// requested but dimensions don't match 5!=6
// The parameters of Indices really do not matter as long as
// we can identify contractions. They are typically expressed as enums, I,J,K...
基本上,鉴于Idx...
和Dims...
都应该具有相同的大小,请检查Idx...
中的哪些值相等,获取它们出现的位置并删除相应的条目(职位)Dims...
。这基本上是tensor contraction rule。
阵列收缩规则:
sizeof...(Idx)==sizeof...(Dims)
Idx
和Dims
之间存在一对一对应关系,即如果我们有Indices<0,1,2>
和Array<double,4,5,6>
,{{1} }映射到0
,4
映射到1
,5
映射到2
。 6
中存在相同/相等的值,则表示收缩,这意味着Idx
中的相应维度应该消失,例如,如果我们有Dims
和{{1然后Indices<0,0,3>
以及这些值映射到Array<double,4,4,6>
和0==0
的相应维度都需要消失,结果数组应该是4
4
具有相同的值,但相应的Array<double,6>
不匹配,则应触发编译时错误,例如Idx
和{{1} } Dims
是不可能的,同样Indices<0,0,3>
不可能作为Array<double,4,5,6>
,这导致4!=5
无法以任何方式签约。Indices<0,1,0>
也匹配,4!=6
就允许多对,三元组,四元组等,例如Array<double,4,5,6>
将与Idx
签约,给定输入数组为Dims
。我对metaprogamming的了解并没有达到这个目的,但是我希望我已经明确了这个意图,因为有人可以指导我朝着正确的方向前进。
答案 0 :(得分:3)
进行实际检查的一堆constexpr
函数:
// is ind[i] unique in ind?
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
return cur == N ? true :
(cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}
// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
return cur == N ? true :
(ind[cur] != index || dim[cur] == dimension) ?
check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}
// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
return is_uniq(ind, i) ? dim[i] :
check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}
现在我们需要一种摆脱-1
s:
template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
: concat<Indices<I1..., I2...>, Inds...> {};
// filter out all instances of I from Is...,
// return the rest as an Indices
template<int I, int... Is>
struct filter
: concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};
使用它们:
template<class Ind, class Arr, class Seq>
struct contraction_impl;
template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
static constexpr int ind[] = { Ind... };
static constexpr int dim[] = { Dim... };
static constexpr int result[] = {calc(Seq, ind, dim)...};
template<int... Dims>
static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;
using type = decltype(unpack_helper(typename filter<-1, result[Seq]...>::type{}));
};
template<class T, int ... Dims, int ... Idx,
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>,
std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);
除make_index_sequence
之外的所有内容都是C ++ 11。你可以在SO上找到大量的实现。
答案 1 :(得分:1)
这是一团糟,但我认为它可以做你想做的事。几乎可以肯定有许多简化可以做到这一点,但这是我通过测试的第一次通过。请注意,这不会实现收缩,而只是确定类型应该是什么。如果那不是您所需要的,我会提前道歉。
#include <type_traits>
template <std::size_t...>
struct Indices {};
template <typename, std::size_t...>
struct Array {};
// Count number of 'i' in 'rest...', base case
template <std::size_t i, std::size_t... rest>
struct Count : std::integral_constant<std::size_t, 0>
{};
// Count number of 'i' in 'rest...', inductive case
template <std::size_t i, std::size_t j, std::size_t... rest>
struct Count<i, j, rest...> :
std::integral_constant<std::size_t,
Count<i, rest...>::value + ((i == j) ? 1 : 0)>
{};
// Is 'i' contained in 'rest...'?
template <std::size_t i, std::size_t... rest>
struct Contains :
std::integral_constant<bool, (Count<i, rest...>::value > 0)>
{};
// Accumulation of counts of indices in all, base case
template <typename All, typename Remainder,
typename AccIdx, typename AccCount>
struct Counts {
using indices = AccIdx;
using counts = AccCount;
};
// Accumulation of counts of indices in all, inductive case
template <std::size_t... all, std::size_t i, std::size_t... rest,
std::size_t... indices, std::size_t... counts>
struct Counts<Indices<all...>, Indices<i, rest...>,
Indices<indices...>, Indices<counts...>>
: std::conditional<Contains<i, indices...>::value,
Counts<Indices<all...>, Indices<rest...>,
Indices<indices...>,
Indices<counts...>>,
Counts<Indices<all...>, Indices<rest...>,
Indices<indices..., i>,
Indices<counts...,
Count<i, all...>::value>>>::type
{};
// Get value in From that matched the first value of Idx that matched idx
template <std::size_t idx, typename Idx, typename From>
struct First : std::integral_constant<std::size_t, 0>
{};
template <std::size_t i, std::size_t j, std::size_t k,
std::size_t... indices, std::size_t... values>
struct First<i, Indices<j, indices...>, Indices<k, values...>>
: std::conditional<i == j,
std::integral_constant<std::size_t, k>,
First<i, Indices<indices...>,
Indices<values...>>>::type
{};
// Return whether all values in From that match Idx being idx are tgt
template <std::size_t idx, std::size_t tgt, typename Idx, typename From>
struct AllMatchTarget : std::true_type
{};
template <std::size_t idx, std::size_t tgt,
std::size_t i, std::size_t j,
std::size_t... indices, std::size_t... values>
struct AllMatchTarget<idx, tgt,
Indices<i, indices...>, Indices<j, values...>>
: std::conditional<i == idx && j != tgt, std::false_type,
AllMatchTarget<idx, tgt, Indices<indices...>,
Indices<values...>>>::type
{};
/* Generate the dimensions, given the counts, indices, and values */
template <typename Counts, typename Indices,
typename AllIndices, typename Values, typename Accum>
struct GenDims;
template <typename A, typename V, typename R>
struct GenDims<Indices<>, Indices<>, A, V, R> {
using type = R;
};
template <typename T, std::size_t i, std::size_t c,
std::size_t... counts, std::size_t... indices,
std::size_t... dims, typename AllIndices, typename Values>
struct GenDims<Indices<c, counts...>, Indices<i, indices...>,
AllIndices, Values, Array<T, dims...>>
{
static constexpr auto value = First<i, AllIndices, Values>::value;
static_assert(AllMatchTarget<i, value, AllIndices, Values>::value,
"Index doesn't correspond to matching dimensions");
using type = typename GenDims<
Indices<counts...>, Indices<indices...>,
AllIndices, Values,
typename std::conditional<c == 1,
Array<T, dims..., value>,
Array<T, dims...>>::type>::type;
};
/* Put it all together */
template <typename I, typename A>
struct ContractionType;
template <typename T, std::size_t... indices, std::size_t... values>
struct ContractionType<Indices<indices...>, Array<T, values...>> {
static_assert(sizeof...(indices) == sizeof...(values),
"Number of indices and dimensions do not match");
using counts = Counts<Indices<indices...>,
Indices<indices...>,
Indices<>, Indices<>>;
using type = typename GenDims<typename counts::counts,
typename counts::indices,
Indices<indices...>, Indices<values...>,
Array<T>>::type;
};
static_assert(std::is_same<typename
ContractionType<Indices<0, 0>, Array<double, 3, 3>>::type,
Array<double>>::value, "");
static_assert(std::is_same<typename
ContractionType<Indices<0, 1>, Array<double, 3, 3>>::type,
Array<double, 3, 3>>::value, "");
static_assert(std::is_same<typename
ContractionType<Indices<0, 1, 0>, Array<double, 3, 4, 3>>::type,
Array<double, 4>>::value, "");
static_assert(std::is_same<typename
ContractionType<Indices<0, 1, 0, 7, 7, 2>,
Array<double, 3, 4, 3, 5, 5, 6>>::type,
Array<double, 4, 6>>::value, "");
// Errors appropriately when uncommented
/* static_assert(std::is_same<typename */
/* ContractionType<Indices<10,10, 2, 3>, */
/* Array<double, 5,6,4,4>>::type, */
/* Array<double>::value, ""); */
以下是对此处发生的事情的解释:
Counts
生成唯一索引列表(Counts::indices
)以及每个索引在序列中显示的次数(Counts::counts
)。Counts
的对,并且对于每个索引,如果计数为1,我累积值并递归。否则,我传递累计值并递归。最令人恼火的部分是static_assert
中的GenDims
,它会验证所有匹配尺寸相同的索引。