VexCL:vexcl向量中的最大值索引

时间:2014-07-10 09:59:48

标签: c++ vexcl

如何在VexCL向量中找到最大值的索引?我可以找到最大值:

int h[] = {3, 2, 1, 5, 4};
vex::vector<int> d(ctx, 5);
vex::copy(h, d);

vex::Reductor<int, vex::MAX> max(ctx.queue());
int m = max(d);

哪个提供m = 5,但有没有办法找到最大值的索引ind = 3

1 个答案:

答案 0 :(得分:2)

你需要

  1. 在vexcl表达式中编码矢量值和矢量位置,
  2. 为vex :: Reductor创建自定义函子,它将根据第一个组件减少上述表达式。
  3. 以下是工作代码:

    #include <iostream>
    #include <vector>
    #include <vexcl/vexcl.hpp>
    
    // This function converts two integers to cl_int2
    VEX_FUNCTION(cl_int2, make_int2, (int, x)(int, y),
            int2 v = {x, y};
            return v;
            );
    
    // This struct compares OpenCL vector types by the first component.
    struct MAX0 {
        template <class Tn>
        struct impl {
            typedef typename vex::cl_scalar_of<Tn>::type T;
    
            // Initial value.
            static Tn initial() {
                Tn v;
    
                if (std::is_unsigned<T>::value)
                    v.s[0] = static_cast<T>(0);
                else
                    v.s[0] = -std::numeric_limits<T>::max();
    
                return v;
            }
    
            // Device-side function call operator
            struct device : vex::UserFunction<device, Tn(Tn, Tn)> {
                static std::string name() { return "MAX_" + vex::type_name<Tn>(); }
                static std::string body() { return "return prm1.x > prm2.x ? prm1 : prm2;"; }
            };
    
            // Host-side function call operator
            Tn operator()(Tn a, Tn b) const {
                return a.s[0] > b.s[0] ? a : b;
            }
        };
    };
    
    int main(int argc, char *argv[]) {
        vex::Context ctx( vex::Filter::Env );
    
        std::vector<int> h = {3, 2, 1, 5, 4};
        vex::vector<int> d(ctx, h);
    
        // Create reductor based on MAX0 operation,
        // then reduce an expression that encodes both value and position of a
        // vector element:
        vex::Reductor<cl_int2, MAX0> max(ctx);
    
        cl_int2 m = max(make_int2(d, vex::element_index()));
    
        std::cout << "max value of " << m.s[0] << " at position " << m.s[1] << std::endl;
    }
    

    此输出

    max value of 5 at position 3