我是CUDA的新手(大约一小时前安装了numba)。我想加快课堂内的这个功能。
def predict(self, X):
num_test = X.shape[0]
Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
for i in range(num_test):
distances = np.sum(np.abs(self.Xtr-X[i, :]), axis=1)
min_index = np.argmin(distances)
Ypred[i] = self.ytr[min_index]
print(i)
return Ypred
X是float32类型的2D数组,Ypred是int32类型的数组。我试图通过在函数上方插入以下行来加速它。
@vectorize(['int32(float32)'], target='cuda')
这给了我一个巨大的错误列表,但其重要部分似乎是:
TypeError: Failed at nopython (analyzing bytecode)
Signature mismatch: 1 argument types given, but function takes 2 arguments
虽然我确切地知道错误说的是什么,但我不知道如何解决它。那么......我该如何让它发挥作用?提前谢谢。
更新
在询问之前我应该做一个正确的谷歌搜索(我做了搜索,但我使用了' object'而不是' class',这对我没用结果)。 documentation给了我很多帮助,但现在我的脸上出现了这些错误,而且我不知道该怎么做。
numba.errors.LoweringError: Failed at nopython (nopython mode backend)
Can only insert float* at [4] in {i8*, i8*, i64, i64, float*, [2 x i64],
[2 x i64]}: got double*
File "main.py", line 40
[1] During: lowering "(self).Xtr = X" at D:/myStuff/DL/Week 3/1/main.py (40)
[2] During: resolving callee type:
BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'train') f for instance.jitclass.NearestNeighbours#24f01184f58<Xtr:array(float32, 2d, A),ytr:array(int32, 1d, A)>)
[3] During: typing of call at <string> (3)
--%<-----------------------------------------------------------------
File "<string>", line 3
这里是整个班级的当前状态:
spec = [("Xtr", float32[:, :]), ("ytr", int32[:])]
@jitclass(spec)
class NearestNeighbours(object):
def __init__(self):
pass
def train(self, X, y):
self.Xtr = X #line 40
self.ytr = y
def predict(self, X):
num_test = X.shape[0]
Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
for i in range(num_test):
distances = np.sum(np.abs(self.Xtr-X[i, :]), axis=1)
min_index = np.argmin(distances)
Ypred[i] = self.ytr[min_index]
print(i)
return Ypred
更新2: 放弃jitting这个类并尝试将predict链接到它的外部克隆。使用一个空的jit似乎工作,但链接到cuda(速度)导致各种奇特的错误。 我今天休息一下,如果我以某种方式解决问题,我会回答我自己的问题。直到几个小时前,我认为GPU加速就像添加一个额外的库或切换到不同的编译器或其他东西一样简单......但是男人......我不知道我是否会这样做颠簸。
答案 0 :(得分:1)
据我所知,你的函数只依赖于for(int i=2;i<range;i++){
numberslist.add(numberslist.get(i-1)+numberslist.get(i-2))}
所以没有理由把它作为一个类中的函数。将其声明为静态@staticmethod
(link)或将其从类 - 范围中取出。
Presto:只剩下1个函数参数。