如何在金属中制作模板化计算内核

时间:2019-06-20 13:51:25

标签: macos gpu gpgpu metal

我一直在编写一些Metal计算内核。因此,我编写了一个带有以下声明的内核:

kernel void
myKernel(const device uint32_t *inData [[buffer(MyKernelIn)]],
        device uint32_t *outData [[buffer(MyKernelOut)]],
        uint2                          gid       [[thread_position_in_grid]],
        uint2 thread_position_in_threadgroup         [[thread_position_in_threadgroup]],
        uint2       threads_per_threadgroup      [[threads_per_threadgroup]],
        uint2 threadgroup_position_in_grid       [[threadgroup_position_in_grid]]) 
{ }

现在,我想写一个这样的变量,它的类型为inDatauint8_t的{​​{1}},我该怎么做?

我可以考虑这样做的可能方法:

  1. 使用不同的名称复制我的内核。 (不可扩展)
  2. 传递一些标志,基于该标志,我可以在内核中添加切换用例,以便在任何时候读写floatinData中的任何内存位置时都可以使用。这意味着我创建的任何临时数据也将使用此类逻辑进行转换。 (这将再次在内核代码中引发很多间接调用,不确定是否会影响我的性能)

还有更好的方法吗?我看到Metal Performance Shaders正在outData上工作,它指定了MTLTexture,并且基于pixelFormat,MPS可以处理多种数据类型。关于如何做到的任何见解?

谢谢!

1 个答案:

答案 0 :(得分:2)

一种可行的方法是:

  • inData声明为void*
  • 在内核着色器的主体中,调用模板函数并传递参数。模板函数将按所需的类型进行模板化,并会收到inData作为指向该类型的指针。

可以使用输入参数动态选择要调用的模板函数的哪个变体。但是更好的方法可能是使用函数常量进行选择。这样,选择就被编译进来了。

所以,像这样:

constant int variant [[function_constant(0)]];

template<typename T> void
work(const device void *inData,
     device uint32_t *outData,
     uint2 gid,
     uint2 thread_position_in_threadgroup,
     uint2 threads_per_threadgroup,
     uint2 threadgroup_position_in_grid) 
{
    const device T *data = static_cast<const device T*>(inData);
    // ...
}

kernel void
myKernel(const device void *inData              [[buffer(MyKernelIn)]],
         device uint32_t *outData               [[buffer(MyKernelOut)]],
         uint2 gid                              [[thread_position_in_grid]],
         uint2 thread_position_in_threadgroup   [[thread_position_in_threadgroup]],
         uint2 threads_per_threadgroup          [[threads_per_threadgroup]],
         uint2 threadgroup_position_in_grid     [[threadgroup_position_in_grid]]) 
{
    if (variant == 0)
        work<uint32_t>(inData, outData, gid, thread_position_in_threadgroup,
                       threads_per_threadgroup, threadgroup_position_in_grid);
    else if (variant == 1)
        work<uint8_t>(inData, outData, gid, thread_position_in_threadgroup,
                      threads_per_threadgroup, threadgroup_position_in_grid);
    else
        work<float>(inData, outData, gid, thread_position_in_threadgroup,
                    threads_per_threadgroup, threadgroup_position_in_grid);
}