我正在计算int8
s向量中最常见的数字。当我设置int
s的计数器数组时,Numba抱怨:
@jit(nopython=True)
def freq_int8(y):
"""Find most frequent number in array"""
count = np.zeros(256, dtype=int)
for val in y:
count[val] += 1
return ((np.argmax(count)+128) % 256) - 128
调用它我收到以下错误:
TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))
如果我删除dtype=int
它可以工作,我获得了不错的加速。然而,我很困惑为什么声明一组int
s不起作用。是否有一个已知的解决方法,是否有值得拥有的效率增益?
背景:我正在尝试削减一些重量级代码的微秒。我特别受到numpy.median
的伤害,并一直在调查Numba,但我正在努力改善median
。查找最常用的号码是median
的可接受替代方案,在这里我已经获得了一些性能。上面的numba代码也比numpy.bincount
快。
更新:在接受的答案中输入后,median
int8
个numpy.median
向量的实现。它比@jit(nopython=True)
def median_int8(y):
N2 = len(y)//2
count = np.zeros(256, dtype=np.int32)
for val in y:
count[val] += 1
cs = 0
for i in range(-128, 128):
cs += count[i]
if cs > N2:
return float(i)
elif cs == N2:
j = i+1
while count[j] == 0:
j += 1
return (i + j)/2
大约快一个数量级:
numpy
令人惊讶的是,短向量的性能差异更大,显然是由于>>> a = np.random.randint(-128, 128, 10)
>>> %timeit np.median(a)
The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 20.8 µs per loop
>>> %timeit median_int8(a)
The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 593 ns per loop
向量的开销:
vlookup
这个开销太大了,我想知道是不是有问题。
答案 0 :(得分:6)
只需快速说明,找到最常用的号码通常称为mode,它与中位数类似,因为它是 mean 。 ..在这种情况下compile("org.springframework.cloud:spring-cloud-starter-aws:1.1.0.RELEASE")
// if I don't add the line below, the annotation @MessageMapping is not found :(
// I would have expected that cloud-starter-aws would have taken care of it
compile("org.springframework.cloud:spring-cloud-aws-messaging:1.1.0.RELEASE")
// this has been added to fix an exception happening, please read below
compile("org.springframework.data:spring-data-commons:1.12.1.RELEASE")
会快得多。除非您的数据存在某些约束或特殊情况,there is no guarantee that the mode approximates the median。
如果您仍想计算整数列表的模式,正如您所提到的那样,np.bincount
应该足够了(如果numba更快,则不应该通过多):
np.mean
注意我已将count = np.bincount(y, minlength=256)
result = ((np.argmax(count)+128) % 256) - 128
参数添加到minlength
,只是为了返回代码中相同的256长度列表。但在实践中完全没有必要,因为您只希望np.bincount
,argmax
(没有np.bincount
)将返回一个列表,其长度是minlength
中的最大数量。< / p>
至于numba错误,将y
替换为dtype=int
可以解决问题。 dtype=np.int32
是一个python函数,您在numba标头中指定int
。如果您删除nopython
,则nopython
或dtype=int
也会有效(具有相同的效果)。