我使用cub::DeviceScan
功能,而sample code snippet有一个参数temp_storage_bytes
,用于分配内存(顺便提一下,代码片段永远不会释放)。< / p>
代码片段使用指向cub::DeviceScan
内存的指针调用NULL
函数,该指针触发它以计算函数所需的临时设备内存量,然后返回。必要的临时存储器分配有cudaMalloc
,并且重复指向该存储器的函数调用。然后使用cudaFree
(或可能应该)释放临时内存。
我在不同的浮点数组上重复多次设备扫描,但每个浮点数组的长度相同。
我的问题是,我可以假设temp_storage_bytes
总是相同的价值吗?如果是这样,我可以为许多函数调用执行单个cudaMalloc
和单个cudaFree
。
该示例不清楚如何确定所需的内存以及它是否可以针对给定长度的给定数组进行更改。
答案 0 :(得分:2)
如果您在相同长度的不同阵列上重复调用cub::DeviceScan::InclusiveScan
,则可以假设您只需要调用一次temp_storage_bytes
来确定所需的临时cub::DeviceScan::InclusiveScan
字节数。在下面的示例中,我在相同长度的不同数组上多次调用cub::DeviceScan::InclusiveScan
,并且只使用一次调用cub::DeviceScan::InclusiveScan
来确定临时大小的数量 -
// Ensure printing of CUDA runtime errors to console
#define CUB_STDERR
#include <stdio.h>
#include <algorithm> // std::generate
#include <cub/cub.cuh> // or equivalently <cub/device/device_scan.cuh>
#include <thrust\device_vector.h>
#include <thrust\host_vector.h>
void main(void)
{
// Declare, allocate, and initialize device pointers for input and output
int num_items = 7;
thrust::device_vector<int> d_in(num_items);
thrust::device_vector<int> d_out(num_items);
// Determine temporary device storage requirements for inclusive prefix sum
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in.data(), d_out.data(), num_items);
// Allocate temporary storage for inclusive prefix sum
cudaMalloc(&d_temp_storage, temp_storage_bytes);
for (int k=0; k<10; k++) {
thrust::host_vector<int> h_in(num_items);
thrust::host_vector<int> h_out(num_items,0);
std::generate(h_in.begin(), h_in.end(), rand);
d_in = h_in;
// Run inclusive prefix sum
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in.data(), d_out.data(), num_items);
int difference = 0;
int prev = 0;
for (int i=0; i<num_items; i++) {
h_out[i] = prev + h_in[i];
prev = h_out[i];
int val = d_out[i];
printf("%i %i %i %i\n",i,difference,h_out[i],d_out[i]);
difference = difference + abs(h_out[i] - d_out[i]);
}
if (difference == 0) printf("Test passed!\n");
else printf("A problem occurred!\n");
h_in.shrink_to_fit();
h_out.shrink_to_fit();
}
getchar();
}