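I want to find the index of the first value in the cumulative sum of x that equals or exceeds sum(x) / 2.
Is there a built-in function or a faster way to compute the following?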
import numpy as np

x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])

def half_life_idx(x):
    # Subtract values from half the total until it is used up.
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val
        if middle <= 0:
            break
    return idx

half_life_idx(x)
>> 1
Edit: Could I get some feedback on what is wrong with my question? I am surprised by the number of downvotes coming in.
Answer 0 (score: 2)
You can combine the cumsum and searchsorted methods for a faster version:
def half_life_idx_ww(x):
    cs = np.cumsum(x)
    middle = cs[-1] / 2
    return cs.searchsorted(middle)
For example,
In [167]: x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
In [168]: half_life_idx(x), half_life_idx_ww(x)
Out[168]: (1, 1)
In [169]: w = np.random.gamma(1.5, size=200)
In [170]: half_life_idx(w), half_life_idx_ww(w)
Out[170]: (99, 99)
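Why this works: the cumulative sum of non-negative values is non-decreasing, so a binary search applies, and searchsorted with its default side='left' returns the first index i with cs[i] >= middle. Below is a minimal sketch that checks this against a plain linear scan. The helper name first_at_least and the two-element tie array are made up for illustration, and the approach assumes x contains no negative entries.

```python
import numpy as np

def half_life_idx_ww(x):
    cs = np.cumsum(x)
    middle = cs[-1] / 2
    return cs.searchsorted(middle)  # default side='left'

def first_at_least(cs, target):
    # Hypothetical reference helper: linear scan for the first cs[i] >= target.
    for i, v in enumerate(cs):
        if v >= target:
            return i
    return len(cs)

x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
cs = np.cumsum(x)            # starts 67, 118, 160, ...; total is 235
idx = half_life_idx_ww(x)    # first index with cs[i] >= 117.5, which is 1

# Illustrative tie case: the cumulative sum hits exactly half at index 0.
tie = np.array([1, 1])
tie_idx = half_life_idx_ww(tie)  # 0, because side='left' includes equality
```

If x could contain negative values, the cumulative sum is no longer sorted and the binary search result should not be relied on.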
Answer 1 (score: 2)
Another approach is to use np.argmax; see the function f1 in this example:
import numpy as np

def f0(x):
    # leermeester's original method
    middle = sum(x) / 2
    for idx, val in enumerate(x):
        middle = middle - val
        if middle <= 0:
            break
    return idx

def f1(x):
    # my method using argmax
    cs = x.cumsum()
    # >= so that hitting exactly half the total counts, as in f0 and f2
    return np.argmax(cs >= cs[-1] / 2)

def f2(x):
    # Warren Weckesser's method using searchsorted
    cs = np.cumsum(x)
    middle = cs[-1] / 2
    return cs.searchsorted(middle)
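Some background on the argmax approach: on a boolean array, np.argmax returns the index of the first True, because True is the maximum and ties resolve to the first occurrence. One caveat is that it also returns 0 when no element is True, so the result is only meaningful if at least one element passes the comparison. A small sketch with made-up arrays (not the benchmark data):

```python
import numpy as np

# Illustrative cumulative sum; half of the last value is 98.5.
cs = np.array([67, 118, 160, 197])
mask = cs >= cs[-1] / 2        # [False, True, True, True]
first = int(np.argmax(mask))   # index of the first True

# Caveat: with no True at all, argmax still returns 0.
all_false = np.zeros(4, dtype=bool)
fallback = int(np.argmax(all_false))
```

Unlike searchsorted, argmax scans the whole mask linearly after building it, which is one plausible reason it comes out somewhat slower.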
Here are some benchmarks for each method:
print("small run")
x = np.array([67, 51, 42, 37, 21, 10, 2, 2, 1, 1, 1])
%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))
print("larger run")
x = np.random.rand(int(1.0E3))
%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))
print("very large run")
x = np.random.rand(int(1.0E6))
%timeit(f0(x))
%timeit(f1(x))
%timeit(f2(x))
#a print to make sure all give the same result
print(f0(x),f1(x),f2(x))
Benchmark results:
small run
2.48 µs ± 41.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.47 µs ± 57.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.7 µs ± 49.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
larger run
184 µs ± 2.59 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.2 µs ± 51.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
5.01 µs ± 14.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
very large run
185 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.3 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.64 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
500260 500260 500260
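Conclusion: your method is the fastest for very small arrays, but for larger arrays it is much slower than the answers given (roughly 30 to 70 times slower in the runs above). Warren's solution is consistently faster than mine; mine takes roughly 25% to 30% longer in all three runs.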
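The %timeit lines above only run inside IPython or Jupyter. For a plain Python script, a rough equivalent can be sketched with the standard timeit module. The names via_argmax, via_searchsorted and best_time are hypothetical, the array size is arbitrary, and the timings will differ from the figures above depending on machine and NumPy version.

```python
import timeit
import numpy as np

def via_argmax(x):
    # Restatement of the argmax idea for this sketch.
    cs = x.cumsum()
    return np.argmax(cs >= cs[-1] / 2)

def via_searchsorted(x):
    # Restatement of the searchsorted idea for this sketch.
    cs = np.cumsum(x)
    return cs.searchsorted(cs[-1] / 2)

def best_time(func, x, number=100, repeat=5):
    # Hypothetical helper: best per-call time in seconds over `repeat` rounds.
    timer = timeit.Timer(lambda: func(x))
    return min(timer.repeat(repeat=repeat, number=number)) / number

rng = np.random.default_rng(0)
x = rng.random(100_000)
for name, func in [("argmax", via_argmax), ("searchsorted", via_searchsorted)]:
    print(f"{name}: {best_time(func, x) * 1e6:.1f} µs per call")
```

Taking the minimum over several rounds is a common way to reduce noise from other processes; %timeit reports mean and standard deviation instead, so the numbers are not directly comparable.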