基本和特定非基本类型的类模板运算符重载

时间:2016-08-26 13:42:10

标签: c++ templates cuda operator-overloading

我只是在写一个MathVector类

template<typename T> MathVector
{
   using value_type = T;

   // further implementation
};

但是,该类被认为可以使用基本类型,但也可以使用复合类

template<typename T> Complex
{
   using value_type = T;

   // further implementation
};

提供例如成员函数

template<typename T> Complex<T>& Complex<T>::operator*=(const Complex<T>& c);
template<typename T> Complex<T>& Complex<T>::operator*=(const T& c);

现在,对于MathVector类,也会定义乘法:

template<typename T> MathVector<T>& MathVector<T>::operator*=(const MathVector<T>& c);

这适用于T=double,但对于T=Complex<double>,我希望有可能与double相乘,而无需先将其转换为Complex<double>(效率更高) )。

这一点因代码也应该在CUDA设备代码中工作而更加严重(为简洁起见,我省略了说明符__host__ __device__)。这意味着标准库工具没有帮助。

首先我想到了一个额外的模板参数

template<typename T, typename U> MathVector<T>& MathVector<T>::operator*=(const U& c);

但这对我来说似乎很危险,因为U可能比TT::value_type更多。 (实际上我首先在Complex类中也有这个参数 - 编译器无法再决定使用哪个模板,复杂类或MathVector类之一。)

第二个想法是使用模板专业化

template<typename T, typename U> MathVector<T>& MathVector<T>::operator*=(const U& c)
{
   static_assert(sizeof(T) == 0, "Error...");
}
template<typename T> MathVector<T>& MathVector<T>::operator*=(const typename T::value_type& c)
{
   // implementation
}

但这不再适用于基本类型!

我已经在C++ Operator Overloading for a Matrix Class with Both Real and Complex MatricesReturn double or complex from template function中看到了这个(或非常类似的)问题的解决方案,但它们是使用标准库以CUDA无法实现的方式解决的。

所以我的问题是:是否有一种方法可以重载使用基本类型的运算符以及为value_type提供但不为其他类型提供服务的类型 - 而不使用std:: nvcc的内容编译器会拒绝吗?

2 个答案:

答案 0 :(得分:2)

您可以制作operator*=非成员函数模板,并提供所有重载,使SFINAE生效。

template<typename T>
MathVector<T>& operator*=(MathVector<T>& m, const MathVector<T>& c);
template<typename T>
MathVector<T>& operator*=(MathVector<T>& m, const T& c);
template<typename T>
MathVector<T>& operator*=(MathVector<T>& m, const typename T::value_type& c);

然后将其称为:

MathVector<Complex<double>> m1;
m1 *= MathVector<Complex<double>>{};  // call the 1st one
m1 *= Complex<double>{};              // call the 2nd one
m1 *= 0.0;                            // call the 3rd one

MathVector<double> m2;
m2 *= MathVector<double>{};           // call the 1st one
m2 *= 0.0;                            // call the 2nd one

LIVE

答案 1 :(得分:2)

使用SFINAE和decltype,您可以执行类似(c ++ 11)的操作:

template<typename T, typename U>
auto
MathVector<T>::operator*=(const U& c)
-> decltype(void(std::declval<T&>() *= c), std::declval<MathVector<T>&>())
{
    // Your implementation
}