我正试图找到最快的方法来获得numpy的'where'语句在2D numpy数组上的功能;即,检索满足条件的索引。它比我使用的其他语言(例如,IDL,Matlab)慢得多。
我有一个cythonized函数,它在嵌套的for循环中遍历数组。速度几乎有一个数量级的增加,但如果可能的话,我想更多地提高性能。
TEST.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0
t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
我的cython_where.pyx程序:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
assert data.dtype == DTYPE1
cdef int xmax = data.shape[0]
cdef int ymax = data.shape[1]
cdef unsigned int x, y
cdef int count = 0
cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
for x in xrange(xmax):
for y in xrange(ymax):
if(data[x,y] == val):
xind[count] = x
yind[count] = y
count += 1
return tuple([xind[0:count],yind[0:count]]),count
TEST.py的输出:
cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
我也尝试了numpy的argwhere
,其速度与where
一样快。我对numpy和cython很陌生,所以如果你有任何其他的想法来真正提高性能,我会全力以赴!
答案 0 :(得分:3)
提供内容:
Numpy可以在平顶阵列上加速,获得4倍的增益:
%timeit np.where(data==10)
1 loops, best of 3: 105 ms per loop
%timeit np.unravel_index(np.where(data.ravel()==10),data.shape)
10 loops, best of 3: 26.0 ms per loop
我认为你可以用它来优化你的cython代码,避免为每个单元格计算k=i*ncol+j
。
Numba提供了一个简单的替代方案:
from numba import jit
dtype=data.dtype
@jit(nopython=True)
def numbaeq(flatdata,x,nrow,ncol):
size=ncol*nrow
ix=np.empty(size,dtype=dtype)
jx=np.empty(size,dtype=dtype)
count=0
k=0
while k<size:
if flatdata[k]==x :
ix[count]=k//ncol
jx[count]=k%ncol
count+=1
k+=1
return ix[:count],jx[:count]
def whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)
给出:
%timeit whereequal(data,10)
10 loops, best of 3: 20.2 ms per loop
在cython性能下,对于这类问题的numba并不是很好的优化。
k//ncol
和k%ncol
可以使用优化的divmod
操作同时计算。