我有Vector(CVector<T, std::size_t Size>
),Matrix(CMatrix<T, std::size_t Height, std::size_t Width>
)和Tensor(CTensor<T, std::size_t... Sizes>
)类,我希望能够从CTensor
类隐式转换如果是CVector
那么sizeof...(Sizes) == 1
课程和CMatrix
sizeof...(Sizes) == 2
课程,那么我有以下转化运算符(最初我没有std::enable_if
模板参数希望我可以使用SFINAE来防止它编译):
template <typename std::enable_if<sizeof...(Sizes) == 2, int>::type = 0>
operator CMatrix<NumType, Sizes...>() const
{
static_assert(sizeof...(Sizes) == 2, "You can only convert a rank 2 tensor to a matrix");
CMatrix<NumType, Sizes...> matResult;
auto& arrThis = m_numArray;
auto& arrResult = matResult.m_numArray;
concurrency::parallel_for_each( arrResult.extent, [=, &arrThis, &arrResult]( concurrency::index<2> index ) restrict( amp ) {
arrResult[index] = arrThis[index];
} );
return matResult;
}
template <typename std::enable_if<sizeof...(Sizes) == 1, int>::type = 0>
operator CVector<NumType, Sizes...>() const
{
static_assert(sizeof...(Sizes) == 1, "You can only convert a rank 1 tensor to a vector");
CVector<NumType, Sizes...> vecResult;
auto& arrThis = m_numArray;
auto& arrResult = vecResult.m_numArray;
concurrency::parallel_for_each( arrResult.extent, [=, &arrThis, &arrResult]( concurrency::index<1> index ) restrict( amp ) {
arrResult[index] = arrThis[index];
} );
return vecResult;
}
但是,如果我实例化CTensor<float, 3, 3, 3>
并尝试编译,我将会收到错误,声明CMatrix
和CVector
的模板参数太多以及错误关于std::enable_if<false, int>
的缺失类型。有没有办法实现这些运算符而不必将CTensor
专门用于等级1和2?
答案 0 :(得分:3)
我简化了以前的解决方案,详情如下。
根本不需要SFINAE,因为模板方法中只有static_assert
,只有在使用时才会实例化。
我的解决方案使转换运算符成为具有依赖参数的模板方法(以便编译器不实例化其主体,仅解析签名),并添加假装缺失的-1
大小尺寸为1的张量(不是张量本身,而是提取参数包的辅助类),允许编译器实例化张量模板本身,但以后不允许在无效维度的张量内实例化转换运算符。 / p>
#include <cstddef>
template <typename T, unsigned int index, T In, T... args>
struct GetArg
{
static const T value = GetArg<T, index-1, args...>::value;
};
template <typename T, T In, T... args>
struct GetArg<T, 0, In, args...>
{
static const T value = In;
};
template <typename T, T In>
struct GetArg<T, 1, In>
{
static const T value = -1;
};
template <typename T, std::size_t Size>
struct CVector
{
};
template <typename T, std::size_t Height, std::size_t Width>
struct CMatrix
{
};
template <typename T, std::size_t... Sizes>
struct CTensor
{
template <std::size_t SZ = sizeof...(Sizes)>
operator CVector<T, GetArg<std::size_t, 0, Sizes...>::value>() const
{
static_assert(SZ == 1, "You can only convert a rank 1 tensor to a vector");
CVector<T, Sizes...> vecResult;
return vecResult;
}
template <std::size_t SZ = sizeof...(Sizes)>
operator CMatrix<T, GetArg<std::size_t, 0, Sizes...>::value, GetArg<std::size_t, 1, Sizes...>::value>() const
{
static_assert(SZ == 2, "You can only convert a rank 2 tensor to a matrix");
CMatrix<T, Sizes...> matResult;
return matResult;
}
};
int main()
{
CTensor<float, 3> tensor3;
CTensor<float, 3, 3> tensor3_3;
CTensor<float, 3, 3, 3> tensor3_3_3;
CVector<float, 3> vec(tensor3);
//CVector<float, 3> vec2(tensor3_3); // static_assert fails!
CMatrix<float, 3, 3> mat(tensor3_3);
//CMatrix<float, 3, 3> mat2(tensor3_3_3); // static_assert fails!
}
答案 1 :(得分:2)
以下是static_assert
:
template <typename NumType,size_t... Sizes>
struct CTensor {
template<size_t n,size_t m>
operator CMatrix<NumType,n,m>() const
{
static_assert(
sizeof...(Sizes)==2,
"You can only convert a rank 2 tensor to a matrix"
);
static_assert(
std::is_same<CTensor<NumType,n,m>,CTensor>::value,
"Size mismatch"
);
...
}
template<size_t n>
operator CVector<NumType,n>() const
{
static_assert(
sizeof...(Sizes)==1,
"You can only convert a rank 1 tensor to a vector"
);
static_assert(
std::is_same<CTensor<NumType,n>,CTensor>::value,
"Size mismatch"
);
...
}
};
或与SFINAE:
template <typename NumType,size_t... Sizes>
struct CTensor {
template<size_t n,size_t m,
typename =
typename std::enable_if<
std::is_same<CTensor<NumType,n,m>,CTensor>::value, int
>::type
>
operator CMatrix<NumType,n,m>() const
{
...
}
template<size_t n,
typename =
typename std::enable_if<
std::is_same<CTensor<NumType,n>,CTensor>::value, int
>::type
>
operator CVector<NumType,n>() const
{
...
}
};
这是使用函数重载的另一种方法:
template <typename NumType,size_t... Sizes>
struct CTensor {
template<size_t n,size_t m>
CMatrix<NumType,n,m> convert() const
{
...
}
template<size_t n>
CVector<NumType,n> convert() const
{
...
}
template <typename T>
operator T() const { return convert<Sizes...>(); }
};
答案 2 :(得分:1)
这实际上是对我的评论的更长描述:为什么不仅使用CTensor并将其别名为CVector / CMatrix?不需要转换,它们将变为相同。
......它以与标题要求完全不同的方式解决了真正的问题。仅供记录:)
1)隐藏名称空间detail
中的基础实现
2)专注于真正需要专业化的事项
(这可以通过一些辅助结构来完成 - 专门提供方法的结构)
3)将CVector / CMatrix混淆为CTensor (当时不需要操作员)
#include <vector>
namespace detail {
template<class T, std::size_t... Sizes>
class base;
template<class T, std::size_t Size>
class base<T, Size> {
std::vector<T> data;
public:
T& operator[](std::size_t i) {
return data[i]; }
};
template<class T, std::size_t First, std::size_t... More>
class base<T, First, More...> {
std::vector<base<T, More...>> data;
public:
// this could be done better, just an example
base<T, More...>& operator[](std::size_t i) {
return data[i]; }
};
}
template<class T, std::size_t... Sizes>
class CTensor: public detail::base<T, Sizes...> {};
//we can specialize CTensor<T, Size>
//and CTensor<T, Width, Height> here
template<class T, std::size_t Size>
using CVector = CTensor<T, Size>;
template<class T, std::size_t Width, std::size_t Height>
using CMatrix = CTensor<T, Width, Height>;
答案 3 :(得分:0)
使sizeof...(Sizes)
成为一个依赖参数,并使CMatrix
/ CVector
类型更正(采用正确数量的模板参数)。
使用:
template <std::size_t ... Is> struct index_sequence {};
template <std::size_t I, typename T> struct index_element;
template <std::size_t I, std::size_t ... Is>
struct index_element<I, index_sequence<Is...> >
{
private:
static constexpr const std::size_t a[] = {Is...};
public:
static_assert(I < sizeof...(Is), "out of bound");
static constexpr const std::size_t value = a[I];
};
然后你可以这样做:
template <
std::size_t N = sizeof...(Sizes),
typename std::enable_if<N == 1, int>::type = 0>
operator CVector<
T,
index_element<0, index_sequence<Sizes..., 0>
>::value>() const
{
// Your implementation
}
template <
std::size_t N = sizeof...(Sizes),
typename std::enable_if<N == 2, int>::type = 0>
operator CMatrix<
T,
index_element<0, index_sequence<Sizes..., 0>>::value
index_element<1, index_sequence<Sizes..., 0, 0>>::value
>() const
{
// Your implementation
}