Is there a faster way to find the index of a value in the cumulative sum of an array?

Time: 2018-01-29 12:37:59

Tags: python numpy

I want to find the index of the first value in the cumulative sum of x that equals or exceeds sum(x) / 2.

Is there a built-in function, or a faster way, to compute the following?

import numpy as np

x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])

def half_life_idx(x):
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val 
        if middle <= 0:         
            break

    return idx

half_life_idx(x)

>> 1
Edit: Could I get some feedback on what is wrong with my question? I am surprised by the number of downvotes coming in.

2 answers:

Answer 0: (score: 2)

You can combine cumsum with the searchsorted method to get a faster version:

def half_life_idx_ww(x):
    cs = np.cumsum(x)
    middle = cs[-1]/2
    return cs.searchsorted(middle)

For example,

In [167]: x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])

In [168]: half_life_idx(x), half_life_idx_ww(x)
Out[168]: (1, 1)

In [169]: w = np.random.gamma(1.5, size=200)

In [170]: half_life_idx(w), half_life_idx_ww(w)
Out[170]: (99, 99)
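
A caveat worth spelling out: searchsorted performs a binary search, so it assumes the cumulative sum is sorted in ascending order, which holds when every element of x is non-negative. Below is a minimal sketch of a guarded variant. The name half_life_idx_checked and the choice to raise ValueError are illustrative assumptions, not part of the answer above.

```python
import numpy as np

def half_life_idx_checked(x):
    # Hypothetical helper: same cumsum + searchsorted idea, plus input checks.
    x = np.asarray(x)
    if x.size == 0:
        raise ValueError("x must not be empty")
    if np.any(x < 0):
        # With negative entries the cumulative sum is not sorted,
        # so a binary search over it is not valid.
        raise ValueError("x must be non-negative for searchsorted to be valid")
    cs = np.cumsum(x)
    # side="left" (the default) returns the first index i with cs[i] >= cs[-1] / 2.
    return int(np.searchsorted(cs, cs[-1] / 2, side="left"))

x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
print(half_life_idx_checked(x))  # 1
```

For inputs that may contain negative values, a linear scan over the cumulative sum is the safer fallback.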

Answer 1: (score: 2)

Another approach is to use np.argmax; see function f1 in this example:

import numpy as np

def f0(x):
    #leermeester's original method
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val 
        if middle <= 0:         
            break
    return idx

def f1(x):
    #my method using argmax
    #>= (not >) so an exact tie at half the total is included, as the question asks
    cs = x.cumsum()
    return np.argmax(cs>=cs[-1]/2)

def f2(x):
    #Warren Weckesser's method using searchsorted
    cs = np.cumsum(x)
    middle = cs[-1]/2
    return cs.searchsorted(middle)
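
One detail to watch when comparing these approaches: the question asks for the first cumulative value that equals or exceeds half the total. searchsorted with its default side='left' matches that, whereas a strict > in an argmax mask skips an exact tie. Here is a small sketch on a made-up input; the array and the function names are illustrative only.

```python
import numpy as np

def idx_strict(x):
    # argmax over a strict comparison: first cumulative value > half the total
    cs = np.cumsum(x)
    return int(np.argmax(cs > cs[-1] / 2))

def idx_inclusive(x):
    # argmax over an inclusive comparison: first cumulative value >= half the total
    cs = np.cumsum(x)
    return int(np.argmax(cs >= cs[-1] / 2))

def idx_searchsorted(x):
    # binary search, default side='left': first cumulative value >= half the total
    cs = np.cumsum(x)
    return int(cs.searchsorted(cs[-1] / 2))

tie = np.array([1, 1, 1, 1])  # cumulative sum is [1, 2, 3, 4]; half the total is 2.0
print(idx_strict(tie), idx_inclusive(tie), idx_searchsorted(tie))  # 2 1 1
```

With random floating-point data an exact tie is very unlikely, so the variants normally agree there.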

Here are some benchmarks for each method:

print("small run")
x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])

%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))


print("larger run")
x = np.random.rand(int(1.0E3))

%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))

print("very large run")
x = np.random.rand(int(1.0E6))

%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))

#a print to make sure all give the same result
print(f0(x),f1(x),f2(x))

Benchmark results:

small run
2.48 µs ± 41.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.47 µs ± 57.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.7 µs ± 49.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
larger run
184 µs ± 2.59 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.2 µs ± 51.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
5.01 µs ± 14.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
very large run
185 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.3 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.64 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
500260 500260 500260

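Conclusion: your method is marginally the fastest for very small arrays (11 elements here), but for larger arrays it is far slower than either answer (about 185 ms versus about 3 ms at one million elements). Warren's solution is consistently faster than mine, by roughly 20 to 30% in these runs.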
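
Note that %timeit is an IPython magic, so the benchmark as written will not run in a plain Python script. A rough equivalent using the standard-library timeit module is sketched below. The helper name bench, the repeat counts, and the array size are arbitrary choices for illustration, and absolute timings will vary by machine.

```python
import timeit
import numpy as np

def f1(x):
    # argmax variant; >= matches the question's "equals or exceeds"
    cs = x.cumsum()
    return np.argmax(cs >= cs[-1] / 2)

def f2(x):
    # searchsorted variant
    cs = np.cumsum(x)
    return cs.searchsorted(cs[-1] / 2)

def bench(func, x, number=100, repeat=5):
    # Hypothetical helper: best per-call time in seconds over `repeat` rounds.
    times = timeit.repeat(lambda: func(x), number=number, repeat=repeat)
    return min(times) / number

x = np.random.rand(10**5)
for func in (f1, f2):
    print(f"{func.__name__}: {bench(func, x) * 1e6:.1f} µs per call")
```

Taking the minimum over several rounds reduces noise from other processes; %timeit reports a mean and standard deviation instead, so the numbers are not directly comparable.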