具有保护条件的Python标量产品

时间:2017-01-03 14:20:16

标签: python arrays numpy

我有两个相同形状的数组:

a = numpy.array([7, 5, 0, 2, 9, 6, 4, 1])
b = numpy.array([1, 6, 3, 9, 1, 1, 3, 8])

我正在寻找一种很好的方法来创建ab的总和,但是计算会在达到限制l之前停止,然后返回a中用于该计算的元素。

所以在上面的例子中,如果我执行:

  

numpy.dot(a,b)

我得到90。但是,作为一个例子,如果l=50我想要的是一个行为如下的函数:

>>> foo(a,b,l)
[7,5]

2 个答案:

答案 0 :(得分:2)

使用cumsum获取元素乘法求和的一种方法,然后argmax获取限制乘积和的第一个索引,最后slicing转换为a -

a[:((a*b).cumsum()>l).argmax()]

示例运行 -

# Adding 2 as the third elem in 'a' for a better test
In [103]: a = np.array([7, 5, 2, 2, 9, 6, 4, 1]) 
     ...: b = np.array([1, 6, 3, 9, 1, 1, 3, 8])
     ...: 

In [104]: l = 50

In [105]: (a*b).cumsum()
Out[105]: array([ 7, 37, 43, 61, 70, 76, 88, 96])

In [106]: a[:((a*b).cumsum()>l).argmax()]
Out[106]: array([7, 5, 2])

运行时测试 -

In [116]: a = np.random.randint(0,4,(10000))

In [117]: b = np.random.randint(0,4,(10000))

In [118]: l = 10000

In [119]: %timeit a[np.where((a*b).cumsum() < l)] # @wildwilhelm's soln
10000 loops, best of 3: 92.8 µs per loop

In [121]: %timeit a[(a*b).cumsum() < l] # @wildwilhelm's soln w/o np.where
10000 loops, best of 3: 82.6 µs per loop

In [122]: %timeit a[:((a*b).cumsum()>l).argmax()]
10000 loops, best of 3: 71.3 µs per loop

答案 1 :(得分:2)

>>> import numpy as np
>>> l = 50
>>> a[np.where((a*b).cumsum() < l)]
array([7, 5, 0])