返回推力二元函数

时间:2016-02-03 11:57:09

标签: c++ templates thrust

我试图定义一个函数,该函数将根据字符串的内容返回所需的类型运算符。我试过这个,但它不起作用:

impl.cpp

template <typename T> thrust::binary_function<T,T,bool>
 get_filter_operator(const std::string &op)
    if (op == "!=")
        return thrust::not_equal_to<T>();
    else if (op == ">")
        return thrust::greater<T>();
    else if (op == "<")
        return thrust::less<T>();
    else if (op == ">=")
        return thrust::greater_equal<T>();
    else if (op == "<=")
        return thrust::less_equal<T>();
    else
    {
        return thrust::equal_to<T>();
    }

template thrust::binary_function<float,float,bool> get_filter_operator<float>(const std::string &);

impl.h

template <typename T> thrust::binary_function<T, T, bool> get_filter_operator(const std::string &op);

如何返回指向thrust::not_equal_to<int>()thrust::equal_to<int>()等任意函数的指针?我无法找到要返回的正确类型。

修改

根据要求,编译错误:

在'thrust :: binary_function&lt; T,T,bool&gt;的实例化中get_filter_operator(const string&amp;)[with T = float; std :: string = std :: basic_string&lt; char&gt;]':

错误:无法将'thrust :: equal_to&lt; float&gt;()'从'thrust :: equal_to&lt; float&gt;'转换为'thrust :: binary_function&lt; float,float,bool&gt;'          return thrust :: equal_to()

更新

很抱歉没有提到过这个:问题是我不能使用std :: function,因为它只适用于主机代码。我想使用推力二进制函数,以便我可以在GPU和CPU中使用它们。

3 个答案:

答案 0 :(得分:3)

  

如何返回指向任意函数的指针,如thrust :: not_equal_to()或thrust :: equal_to()?   我找不到要返回的正确类型

您尝试返回的每件事都是两个参数的函数, 每种类型T都返回bool。正确的返回类型是

std::function<bool(T, T)>

如:

#include <thrust/functional.h>
#include <functional>
#include <string>

template<typename T>
std::function<bool(T, T)>
get_filter_operator(const std::string &op)
{
    if (op == "!=")
        return thrust::not_equal_to<T>();
    else if (op == ">")
        return thrust::greater<T>();
    else if (op == "<")
        return thrust::less<T>();
    else if (op == ">=")
        return thrust::greater_equal<T>();
    else if (op == "<=")
        return thrust::less_equal<T>();
    else
    {
        return thrust::equal_to<T>();
    }
}

#include <iostream>

using namespace std;

int main()
{
    auto relop = get_filter_operator<int>("!=");
    cout << boolalpha << relop(1,0) << endl;
    cout << boolalpha << relop(1,1) << endl;

    return 0;
}

现在,您可能希望将您的评论重新重复到@MohamadElghawi:

  

是的,我知道这很有效,但问题在于我试图回归   thrust :: binary_function,而不是std

这可能是你想要做的,但这是错误的 试图去做一件不可能完成的事情。看看的定义 template<typename A1, typename A2, typename R> struct thrust::binary_function<thrust/functional>和相关文档中。注意:

  

binary_function是一个空基类:它不包含任何成员函数   或成员变量,但只有类型信息

特别是,thrust::binary_function<A1,A2,R>没有operator()。 它不可调用。它不能存储任何其他可调用对象(或 什么都没有)。另请参阅equal_tonot_equal_to的定义, 等在同一个文件中。 binary_function不是其中任何一个的平等基础。 没有任何转化为binary_function

请注意:

  

binary_function目前是C ++ STL类型的冗余   的std :: binary_function。我们在这里预留可能的额外费用   以后的功能。

std::binary_function在C ++ 11中已被弃用,将在C ++中删除17)。

thrust::binary_function<T,T,bool>不是您想要的。 std::function<bool(T, T)>

std::function<bool(int, int)> f = thrust::greater<int>(); 

使f封装了一个thrust::greater<int>

的可调用对象

<强>后来

  

这个问题是它只能在主机代码中使用吗?   推力二进制函数的优点在于它们可以在GPU和CPU中使用。

我认为你的印象可能是,例如

std::function<bool(int, int)> f = thrust::greater<int>();  /*A*/

采用thrust::greater<int>并以某种方式将其降级为a std::function<bool(int, int)>有类似但更受限制的std::function<bool(int, int)> foo (&#34; std&#34;)执行能力。

没有类似的情况。 bar只是一个容器 可以使用两个隐式参数调用的任何int 可转换为bool并返回隐式可转换为std::function<bool(int, int)> foo = bar; 的内容, 如果:

foo(i,j)

然后当你致电bool时,你会收到bar(i,j)的结果 执行 bar(i,j)。不是执行任何不同的任何结果的结果 来自/*A*/

因此,在上面的f中,thrust::greater<int>() 包含并调用的可调用事物是 推力二元函数; f。方法 由operator() thrust::greater<int>::operator() #include <thrust/functional.h> #include <functional> #include <iostream> using namespace std; int main() { auto thrust_greater_than_int = thrust::greater<int>(); std::function<bool(int, int)> f = thrust_greater_than_int; cout << "f " << (f.target<thrust::greater<int>>() ? "calls" : "does not call") << " a thrust::greater<int>" << endl; cout << "f " << (f.target<thrust::equal_to<int>>() ? "calls" : "does not call") << " a thrust::equal_to<int>" << endl; cout << "f " << (f.target<std::greater<int>>() ? "calls" : "does not call") << " an std::greater<int>" << endl; cout << "f " << (f.target<std::function<bool(int,int)>>() ? "calls" : "does not call") << " an std::function<bool(int,int)>" << endl; return 0; } 调用。

这是一个程序:

thrust::greater<int>

std::function<bool(int, int)> f存储在f calls a thrust::greater<int> f does not call a thrust::equal_to<int> f does not call an std::greater<int> f does not call an std::function<bool(int,int)> 中 然后通知你:

public string[,] DataArray;

答案 1 :(得分:1)

即使没有确切的答案,我也会把我最终使用的内容放在这里,万一有人需要类似的东西。

在.cuh文件中

Also, what resolution is best to design for?

然后像这样使用它: 在cpp文件中:

#include <cuda.h>
#include <cuda_runtime_api.h>

namespace BinaryFunction
{
    enum class ComparisonOperator
    {
      equal_to,
      not_equal_to,
      greater,
      less,
      greater_equal,
      less_equal
    };

    enum class BitwiseOperator
    {
      bit_and,
      bit_or
    };

    template<typename T>
    struct CompareFunction
    {
      __host__ __device__ T operator()(const T &lhs, const T &rhs, ComparisonOperator &op) const
      {
            switch (op)
            {
                case ComparisonOperator::equal_to:
                    return lhs==rhs;
                case ComparisonOperator::not_equal_to:
                    return lhs!=rhs;
                case ComparisonOperator::greater:
                    return lhs>rhs;
                case ComparisonOperator::less:
                    return lhs<rhs;
                case ComparisonOperator::greater_equal:
                    return lhs>=rhs;
                case ComparisonOperator::less_equal:
                    return lhs<=rhs;

            }
        }
    };

    template<typename T>
    struct BitwiseFunction
    {
      __host__ __device__ T operator()(const T &lhs, const T &rhs, BitwiseOperator &op) const
      {
        if (op==BitwiseOperator::bit_and)
          return lhs & rhs;
        else if (op==BitwiseOperator::bit_or)
          return lhs | rhs;
      }
    };
}

然后,在内核或正常函数中:

BinaryFunction::ComparisonOperator comp_op = BinaryFunction::ComparisonOperator::equal_to;

BinaryFunction::CompareFunction<int> comp_func;

答案 2 :(得分:0)

以下适用于我。

#include <iostream>
#include <functional>

template<class T>
std::function<bool(T, T)> GetOperator(const std::string& op)
{
    if (op == "!=")
        return std::not_equal_to<T>();
    else if (op == ">")
        return std::greater<T>();
    else if (op == "<")
        return std::less<T>();
    else if (op == ">=")
        return std::greater_equal<T>();
    else if (op == "<=")
        return std::less_equal<T>();
    else
    {
        return std::equal_to<T>();
    }
}

int main()
{
    auto op = GetOperator<int>(">");
    std::cout << op(1, 2) << '\n';
    return 0;
}