在Numba中使用cuda.jit的正确方法

时间:2019-08-23 02:40:40

标签: cuda numba

试图找出如何在Numba的cuda.jit中进行矩阵矢量乘法,但是我得到了错误的答案

import numpy as np
import numba
from numba import cuda
m = 2 
n = 3
@cuda.jit('void(f4[:,:], f4[:], f4[:])')
def cu_matrix_vector(A, b, c):
    row = cuda.grid(1)
    if (row < m):
        temp = 0
        for i in range(n):
            temp += A[row, i] * b[i]
        c[row] = temp

A = np.array([[1, -1, 2], [0, -3, 1]], dtype=np.float32)
B = np.array([2, 1, 0], dtype=np.float32)
C = np.empty((2,))

dA = cuda.to_device(A)
dB = cuda.to_device(B)
dC = cuda.to_device(C)

cu_matrix_vector[(m+511)/512, 512](dA, dB, dC)
print(C)

答案是错误的,无法弄清楚我做错了什么。 请帮助,谢谢。

1 个答案:

答案 0 :(得分:2)

您的代码中至少有2个错误:

  1. numba默认将浮点变量设置为与python默认使用的大小相同,即64位浮点。如果您在签名中指定32位浮点:

    @cuda.jit('void(f4[:,:], f4[:], f4[:])')
    

    传递32位浮点变量很重要。您的C(因此dC)与此不匹配。我们可以使用与AB相同的方法来修复它:

    C = np.empty((2,), dtype=np.float32)
    
  2. numba和CUDA需要在设备和主机之间来回移动数据。从主机打印设备结果时,确保在打印之前已将这些结果复制回(从dC)是很重要的。如果您打印C而不是dC,numba不会自动为您执行此操作。我们可以这样解决:

    print(dC.copy_to_host())
    

有了这些更改,您的代码将为我打印出预期的结果。