我想在C ++中实现张量类,其中张量维度作为类的模板参数给出。琐碎的,它看起来像这样:
template <size_t... TDims>
class Tensor
{/*...*/};
问题是类型Tensor<2, 1>
和Tensor<2>
应该是相同的,即使它们不是。例如,我希望Tensor<3, 7>
和Tensor<7>
的矩阵乘法返回Tensor<3>
,而不是(仅)Tensor<3, 1>
。
解决此问题的一种方法是实现一种元模板函数,从TDims
的末尾删除所有元函数(NumberSequence
只是包含多个数字常量的单一类型):
namespace detail {
template <NumberSequence TDimSequence>
class Tensor
{/*...*/};
}
template <size_t... TDims>
using Tensor = detail::Tensor<typename RemoveOnesFromEnd<TDims...>::type>;
这使得Tensor<2>
和Tensor <2, 1>
实际上是同一类型。我成功地实现了struct RemoveOnesFromEnd
来完成它应该做的事情,但这又开启了另一个问题:将Tensor
作为函数参数现在可以防止扣除模板参数。例如,使用此矩阵乘法,当给定两个张量时,编译器无法推导出三个size_t
值:
template <size_t TRows1, size_t TCols1Rows2, size_t TCols2>
Tensor<TRows1, TCols2> multiply(Tensor<TRows1, TCols1Rows2> t1, Tensor<TCols1Rows2, TCols2> t2)
{/*...*/}
理论上,明确定义size_t
参数应具有的值,但由于RemoveOnesFromEnd
添加的间接层次,编译器无法处理它。
Tensor<2, 1>
和Tensor<2>
成为完全相同的类型,同时允许成功减少模板参数?另一种解决方案是让Tensor<..., 1>
从Tensor<...>
继承,直到最后没有1,但从我可以告诉它将遇到与模板推导相同的问题,以前的解决方案会,因为我必须有一个元模板函数来反转张量的维度,以便我可以检查最后一个维度是否为1(因为我只能检查之前出现的参数< / em>一个参数包),然后会添加相同的问题级别的间接。事实上,我也在RemoveOnesFromEnd
中使用了逆转技术。