我正在尝试使用Thrust减少值数组的最小值和最大值,我似乎陷入困境。给定一系列浮点数我想要的是在一次传递中减少它们的最小值和最大值,但是使用推力的减少方法我会得到所有模板编译错误的母亲(或至少是阿姨)。
我的原始代码包含5个值列表,这些值分布在2个我想减少的float4数组中,但我已经将它简化为这个简短的例子。
struct ReduceMinMax {
__host__ __device__
float2 operator()(float lhs, float rhs) {
return make_float2(Min(lhs, rhs), Max(lhs, rhs));
}
};
int main(int argc, char *argv[]){
thrust::device_vector<float> hat(4);
hat[0] = 3;
hat[1] = 5;
hat[2] = 6;
hat[3] = 1;
ReduceMinMax binary_op_of_dooooom;
thrust::reduce(hat.begin(), hat.end(), 4.0f, binary_op_of_dooooom);
}
如果我把它分成2个减少而不是当然有效。我的问题是:是否有可能通过推力减少一次通过的最小值和最大值以及如何?如果没有,那么实现减少的最有效方法是什么?转换迭代器会帮助我(如果是这样,那么减少会减少一次吗?)
其他一些信息: 我正在使用Thrust 1.5(由CUDA 4.2.7提供) 我的实际代码是使用reduce_by_key,而不仅仅是reduce。 我在写这个问题的过程中发现了transform_reduce,但那个问题没有考虑到密钥。
答案 0 :(得分:3)
正如talonmies所指出的,你的缩减不会编译,因为thrust::reduce
期望二元运算符的参数类型与其结果类型匹配,但ReduceMinMax
的参数类型是float
,而其结果类型为float2
。
thrust::minmax_element
直接实施此操作,但如有必要,您可以使用thrust::inner_product
来实现缩减,这会概括thrust::reduce
:
#include <thrust/inner_product.h>
#include <thrust/device_vector.h>
#include <thrust/extrema.h>
#include <cassert>
struct minmax_float
{
__host__ __device__
float2 operator()(float lhs, float rhs)
{
return make_float2(thrust::min(lhs, rhs), thrust::max(lhs, rhs));
}
};
struct minmax_float2
{
__host__ __device__
float2 operator()(float2 lhs, float2 rhs)
{
return make_float2(thrust::min(lhs.x, rhs.x), thrust::max(lhs.y, rhs.y));
}
};
float2 minmax1(const thrust::device_vector<float> &x)
{
return thrust::inner_product(x.begin(), x.end(), x.begin(), make_float2(4.0, 4.0f), minmax_float2(), minmax_float());
}
float2 minmax2(const thrust::device_vector<float> &x)
{
using namespace thrust;
pair<device_vector<float>::const_iterator, device_vector<float>::const_iterator> ptr_to_result;
ptr_to_result = minmax_element(x.begin(), x.end());
return make_float2(*ptr_to_result.first, *ptr_to_result.second);
}
int main()
{
thrust::device_vector<float> hat(4);
hat[0] = 3;
hat[1] = 5;
hat[2] = 6;
hat[3] = 1;
float2 result1 = minmax1(hat);
float2 result2 = minmax2(hat);
assert(result1.x == result2.x);
assert(result1.y == result2.y);
}