I have been writing some Metal compute kernels, and I wrote one with the following declaration:
kernel void
myKernel(const device uint32_t *inData [[buffer(MyKernelIn)]],
         device uint32_t *outData [[buffer(MyKernelOut)]],
         uint2 gid [[thread_position_in_grid]],
         uint2 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
         uint2 threads_per_threadgroup [[threads_per_threadgroup]],
         uint2 threadgroup_position_in_grid [[threadgroup_position_in_grid]])
{ }
Now I want to write a variant of this kernel that takes inData of type uint8_t or float. How do I do that?
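A possible way I can think of doing this is to add type-dependent logic that is applied whenever I read or write any memory location in inData and outData. That means any temporary data I create would also have to be cast using the same logic. (This would again introduce a lot of indirection in the kernel code, and I'm not sure whether it would affect performance.)

Is there a better way? I see that Metal Performance Shaders work on MTLTexture, which specifies a pixelFormat, and based on that pixelFormat MPS can handle a wide range of data types. Any insight into how that is done?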
Thanks!
Answer 0 (score: 2)
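One approach that might work is to declare inData as void* and, inside the kernel body, call a template function that is templated on the desired element type and receives inData as a pointer to that type.

You could use an input parameter to choose dynamically which variant of the template function to call. A better approach, though, is probably to use a function constant to make the selection. That way, the choice is compiled in when the function is specialized rather than evaluated on every thread at run time.

So, something like this: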
// Selects the element type of inData; the value is supplied when the function is specialized.
constant int variant [[function_constant(0)]];

// Shared implementation, templated on the element type T of inData.
template<typename T> void
work(const device void *inData,
     device uint32_t *outData,
     uint2 gid,
     uint2 thread_position_in_threadgroup,
     uint2 threads_per_threadgroup,
     uint2 threadgroup_position_in_grid)
{
    // Reinterpret the untyped buffer as an array of T.
    const device T *data = static_cast<const device T*>(inData);
    // ...
}

kernel void
myKernel(const device void *inData [[buffer(MyKernelIn)]],
         device uint32_t *outData [[buffer(MyKernelOut)]],
         uint2 gid [[thread_position_in_grid]],
         uint2 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
         uint2 threads_per_threadgroup [[threads_per_threadgroup]],
         uint2 threadgroup_position_in_grid [[threadgroup_position_in_grid]])
{
    // variant is a function constant, so the compiler can drop the untaken branches.
    if (variant == 0)
        work<uint32_t>(inData, outData, gid, thread_position_in_threadgroup,
                       threads_per_threadgroup, threadgroup_position_in_grid);
    else if (variant == 1)
        work<uint8_t>(inData, outData, gid, thread_position_in_threadgroup,
                      threads_per_threadgroup, threadgroup_position_in_grid);
    else
        work<float>(inData, outData, gid, thread_position_in_threadgroup,
                    threads_per_threadgroup, threadgroup_position_in_grid);
}
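
If it helps to experiment with the dispatch pattern without a GPU, here is a rough CPU analog in plain C++. It is only a sketch, not Metal code: the `device` address space and thread attributes are dropped, and the names `Variant`, `myKernelCPU`, and `count` are hypothetical, introduced here for illustration. The body of `work` (copying each element into the `uint32_t` output) is a placeholder, because the real kernel body is elided above.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical selector; the values mirror the `variant` function constant (0, 1, other).
enum class Variant : int { U32 = 0, U8 = 1, F32 = 2 };

// Templated worker: views the type-erased input as an array of T.
// The per-element conversion below is a placeholder for the real kernel logic.
template <typename T>
void work(const void *inData, std::uint32_t *outData, std::size_t count)
{
    const T *data = static_cast<const T *>(inData);
    for (std::size_t i = 0; i < count; ++i)
        outData[i] = static_cast<std::uint32_t>(data[i]);
}

// Dispatcher: picks one instantiation up front, so the per-element loop
// itself contains no type checks.
void myKernelCPU(Variant variant, const void *inData,
                 std::uint32_t *outData, std::size_t count)
{
    switch (variant) {
    case Variant::U32: work<std::uint32_t>(inData, outData, count); break;
    case Variant::U8:  work<std::uint8_t>(inData, outData, count);  break;
    case Variant::F32: work<float>(inData, outData, count);         break;
    }
}
```

One difference to keep in mind: the `switch` here is evaluated at run time on every call, whereas a Metal function constant is fixed when the function is specialized, so in the shader the selection costs nothing per thread.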