获取python中列组合及其各自索引的乘法的最大值

时间:2016-11-05 07:19:44

标签: python arrays performance numpy vectorization

我有numpyM*N维数组,其中数组的每个元素都是float,其值介于0-1之间。

输入:为简单起见,我们考虑一个3 * 4数组:

a=np.array([
[0.1, 0.2, 0.3, 0.6],
[0.3, 0.4, 0.8, 0.7],
[0.5, 0.6, 0.2, 0.1]
])

我想一次考虑3列(比如第一次迭代的col 0,1,2和第二次的1,2,3)并获得3列的所有可能组合的乘法的最大值并且也得到各自价值的指数。

在这种情况下,我应该得到0.5*0.6*0.8=0.24的最大值和给出最大值的值行的索引:(2,2,1)

输出:[[0.24,(2,2,1)],[0.336,(2,1,1)]]

我可以使用循环执行此操作,但我想避免它们,因为它会影响运行时间,无论如何我可以在numpy中执行此操作吗?

1 个答案:

答案 0 :(得分:1)

这是一种使用NumPy strides的方法,对于这种滑动窗口操作非常有效,因为它creates a view into the array而不实际制作副本 -

N = 3 # Window size
m,n = a.strides
p,q = a.shape
a3D = np.lib.stride_tricks.as_strided(a,shape=(p, q-N +1, N),strides=(m,n,n))
out1 = a3D.argmax(0)
out2 = a3D.max(0).prod(1)

示例运行 -

In [69]: a
Out[69]: 
array([[ 0.1,  0.2,  0.3,  0.6],
       [ 0.3,  0.4,  0.8,  0.7],
       [ 0.5,  0.6,  0.2,  0.1]])

In [70]: out1
Out[70]: 
array([[2, 2, 1],
       [2, 1, 1]])

In [71]: out2
Out[71]: array([ 0.24 ,  0.336])

如果需要,我们可以将这两个输出压缩在一起 -

In [75]: zip(out2,map(tuple,out1))
Out[75]: [(0.23999999999999999, (2, 2, 1)), (0.33599999999999997, (2, 1, 1))]