非常慢的numba CUDA在python中

时间:2016-01-28 02:49:45

标签: python numba

我将这个简单的代码运行到numba cuda中,它发现非常慢。 想知道瓶颈吗?

   @cuda.jit('int32(float64,  float64, int32)', device=True)
   def mandelbrot_numbagpu(creal,cimag,maxiter):
       real = creal
       imag = cimag
       for n in range(maxiter):
           real2 = real*real
           imag2 = imag*imag
           if real2 + imag2 > 4.0:
               return n
           imag = 2* real*imag + cimag
           real = real2 - imag2 + creal
       return 0



   @cuda.jit
   def        mandelbrot_set_numbagpu(xmin,xmax,ymin,ymax,width,height,maxiter,n3,r1,r2):
      for i in range(width):
         for j in range(height):
            n3[i,j] = mandelbrot_numbagpu(r1[i],r2[j],maxiter)


   r1 = np.linspace(-2.0,0.5,1000, dtype=np.float )
   r2 = np.linspace(-1.25,1.25,1000, dtype=np.float)
   n3 = np.zeros((1000,1000),  dtype=np.uint8)

   %timeit mandelbrot_set_numbagpu(-2.0,0.5,-1.25,1.25,1000,1000,80,n3,r1,r2)
   #1 loops, best of 3: 4.84 s per loop

如果我在JIT上运行,它会快10倍!....

1 个答案:

答案 0 :(得分:3)

与Numba / CUDA(我认为对CUDA一般)的问题一般,你的函数不应该遍历数组。相反,他们应该处理单个数组元素,Numbda / CUDA处理程序将一大堆数组元素分配给一大堆GPU核心,因此一切都快速并行地进行。这是all documented

不幸的是,这意味着您无法将@jit更改为@cuda.jit,但您必须对其进行调整。

以下作品:

# mandelbrot_numbagpu as before...

# I've removed some of the useless arguments for simplicity
@cuda.jit
def mandelbrot_set_numbagpu(n3,r1,r2,maxiter):
    # numba provides this function for working out which element you're
    # supposed to be accessing
    i,j = cuda.grid(2)
    if i<n3.shape[0] and j<n3.shape[1]: # check we're in range
        # do work on a single element
        n3[i,j] = mandelbrot_numbagpu(r1[i],r2[j],maxiter)

然后将其称为

# you assign a number of threads, and split it into blocks
# this is all in the documentation!
import math
threadsperblock = (16,16)
blockspergrid_x = math.ceil(n3.shape[0] / threadsperblock[0])
blockspergrid_y = math.ceil(n3.shape[1] / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)

mandelbrot_set_numbagpu2[blockspergrid,threadsperblock](n3,r1,r2,80)
# n3, r1 and r2 are defined as before

在我的电脑上,这可以提高3800倍的速度。我不知道与等效的CPU程序相比如何。