I'm trying to use Anaconda Accelerate to write a function that performs element-wise addition, subtraction, multiplication, or division on 2D arrays, but the function I wrote keeps giving wrong answers, and I can't figure out what is going on.
import numpy as np
from numba import cuda

# Define operation codes and functions
ADD, SUB, MUL, DIV = 1, 2, 3, 4

@cuda.jit('void(complex64[:,:], complex64[:,:], int8)')
def math_inplace_2d_cuda(a, b, operation):
    m, n = a.shape[0], a.shape[1]
    i, j = cuda.grid(2)
    if i < m and j < n:
        if operation == ADD: a[i, j] += b[i, j]
        if operation == SUB: a[i, j] -= b[i, j]
        if operation == MUL: a[i, j] *= b[i, j]
        if operation == DIV: a[i, j] /= b[i, j]

def math_inplace_2d_host(a, b, operation):
    m, n = a.shape[0], a.shape[1]
    for i in range(m):
        for j in range(n):
            if operation == ADD: a[i, j] += b[i, j]
            if operation == SUB: a[i, j] -= b[i, j]
            if operation == MUL: a[i, j] *= b[i, j]
            if operation == DIV: a[i, j] /= b[i, j]

# Create arrays
a = np.array([[1., 2], [3, 4]])
b = a.copy() * 2
a_dev = cuda.to_device(a)
b_dev = cuda.to_device(b)

# Threading
threadperblock = 32, 8

def best_grid_size(size, tpb):
    bpg = np.ceil(np.array(size, dtype=np.float) / tpb).astype(np.int).tolist()
    return tuple(bpg)

blockpergrid = best_grid_size(a_dev.shape, threadperblock)
stream = cuda.stream()

# Do operation
op = ADD
math_inplace_2d_host(a, b, op)
math_inplace_2d_cuda[blockpergrid, threadperblock, stream](a_dev, b_dev, op)

print '\nhost\n', a
print '\ndevice\n', a_dev.copy_to_host()
With the values of a and b given above, this program produces the following output (the host and device arrays should be identical):
host
[[ 3. 6.]
[ 9. 12.]]
device
[[ 384. 768.]
[ 1024. 1536.]]
When I try subtraction, I get:
host
[[-1. -2.]
[-3. -4.]]
device
[[ -4.65661287e-10 -1.19209290e-07]
[ -1.19209290e-07 -1.19209290e-07]]
For multiplication:
host
[[ 2. 8.]
[ 18. 32.]]
device
[[ 1.59512330e-314 1.59615943e-314]
[ 1.59672607e-314 1.59732508e-314]]
For division:
host
[[ 0.5 0.5]
[ 0.5 0.5]]
device
[[ 5.25836359e-315 5.25433420e-315]
[ 5.25481893e-315 5.25525520e-315]]
Answer 1 (score: 1)
If I change the jit signature to:

@cuda.jit('void(float64[:,:], float64[:,:], int64)')

or if I change the definitions of a and op to:

a = np.array([[1., 2], [3, 4]]).astype(np.complex64)
...
op = np.int8(ADD)

then, in the latter case with op set to ADD, I get:
host
[[ 3.+0.j 6.+0.j]
[ 9.+0.j 12.+0.j]]
device
[[ 3.+0.j 6.+0.j]
[ 9.+0.j 12.+0.j]]
I would have expected Numba to raise a type error here, but it seems to cast silently and then do the wrong thing. It might be worth raising this on the Numba Google group.
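To make "does the wrong thing" concrete: the garbage values are consistent with the kernel reinterpreting the float64 bytes as complex64 rather than converting the values. A quick NumPy sketch (an added illustration, using ndarray.view to mimic that bitwise reinterpretation) reproduces the 384. seen in the ADD case above:

import numpy as np

# Reinterpret the 8 bytes of each float64 as one complex64 (two float32s).
x = np.array([1.0]).view(np.complex64)   # -> [ 0.+1.875j]
y = np.array([2.0]).view(np.complex64)   # -> [ 0.+2.j]

# Add as complex64, then reinterpret the result bytes back as float64.
print((x + y).view(np.float64))          # -> [ 384.]

Another way to sidestep the mismatch (a sketch, assuming a Numba version that supports lazy compilation) is to drop the explicit signature entirely, so the kernel is specialized on the actual argument types at the first call:

import numpy as np
from numba import cuda

@cuda.jit  # no signature: types are inferred from the arguments
def add_inplace_2d(a, b):
    i, j = cuda.grid(2)
    if i < a.shape[0] and j < a.shape[1]:
        a[i, j] += b[i, j]

a_dev = cuda.to_device(np.array([[1., 2], [3, 4]]))
b_dev = cuda.to_device(np.array([[2., 4], [6, 8]]))
add_inplace_2d[(1, 1), (32, 8)](a_dev, b_dev)
print(a_dev.copy_to_host())   # [[  3.   6.], [  9.  12.]]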