如何在逐个元素的基础上找到哪个numpy数组包含最大值?

时间:2013-04-25 14:44:54

标签: python algorithm numpy

给定一个numpy数组列表,每个数组都有相同的维度,如何找到哪个数组在逐个元素的基础上包含最大值?

e.g。

import numpy as np
def find_index_where_max_occurs(my_list):
    # d = ...  something goes here ...
    return d

a=np.array([1,1,3,1])
b=np.array([3,1,1,1])
c=np.array([1,3,1,1])

my_list=[a,b,c]

array_of_indices_where_max_occurs = find_index_where_max_occurs(my_list)

# This is what I want:
# >>> print array_of_indices_where_max_occurs
# array([1,2,0,0])
# i.e. for the first element, the maximum value occurs in array b which is at index 1 in my_list.

非常感谢任何帮助......谢谢!

3 个答案:

答案 0 :(得分:3)

如果你想要一个数组,还有另一个选择:

>>> np.array((a, b, c)).argmax(axis=0)
array([1, 2, 0, 0])

所以:

def f(my_list):
    return np.array(my_list).argmax(axis=0)

这也适用于多维数组。

答案 1 :(得分:2)

为了它的乐趣,我意识到@Lev的原始答案比他的通用编辑更快,所以这是广义的堆叠版本比np.asarray版本快得多,但它不是很优雅。

np.concatenate((a[None,...], b[None,...], c[None,...]), axis=0).argmax(0)

那是:

def bystack(arrs):
    return np.concatenate([arr[None,...] for arr in arrs], axis=0).argmax(0)

一些解释:

我为每个数组添加了一个新轴:arr[None,...]相当于arr[np.newaxis,...],与arr[np.newaxis,:,:,:]相同,其中...扩展为适当的数字尺寸。这样做的原因是因为np.concatenate将沿着新维度进行堆叠,因为0位于前面,所以None会一直存在。

所以,例如:

In [286]: a
Out[286]: 
array([[0, 1],
       [2, 3]])

In [287]: b
Out[287]: 
array([[10, 11],
       [12, 13]])

In [288]: np.concatenate((a[None,...],b[None,...]),axis=0)
Out[288]: 
array([[[ 0,  1],
        [ 2,  3]],

       [[10, 11],
        [12, 13]]])

如果它有助于理解,这也会起作用:

np.concatenate((a[...,None], b[...,None], c[...,None]), axis=a.ndim).argmax(a.ndim)

现在新轴添加在最后,所以我们必须沿着最后一个轴堆叠和最大化,这将是a.ndim。对于abc为2d,我们可以这样做:

np.concatenate((a[:,:,None], b[:,:,None], c[:,:,None]), axis=2).argmax(2)

这相当于我在上面的评论中提到的dstack(如果数组中不存在,则dstack添加第三个轴以进行堆叠。)

测试:

N = 10
M = 2

a = np.random.random((N,)*M)
b = np.random.random((N,)*M)
c = np.random.random((N,)*M)

def bystack(arrs):
    return np.concatenate([arr[None,...] for arr in arrs], axis=0).argmax(0)

def byarray(arrs):
    return np.array(arrs).argmax(axis=0)

def byasarray(arrs):
    return np.asarray(arrs).argmax(axis=0)

def bylist(arrs):
    assert arrs[0].ndim == 1, "ndim must be 1"
    return [np.argmax(x) for x in zip(*arrs)]



In [240]: timeit bystack((a,b,c))
100000 loops, best of 3: 18.3 us per loop

In [241]: timeit byarray((a,b,c))
10000 loops, best of 3: 89.7 us per loop

In [242]: timeit byasarray((a,b,c))
10000 loops, best of 3: 90.0 us per loop

In [259]: timeit bylist((a,b,c))
1000 loops, best of 3: 267 us per loop

答案 2 :(得分:1)

[np.argmax(x) for x in zip(*my_list)]

嗯,这只是一个列表,但如果你愿意,你知道如何使它成为一个数组。 :)

要说明这一点:zip(*my_list)相当于zip(a,b,c),它为您提供了循环的生成器。循环中的每个步骤都会为您提供类似(a[i], b[i], c[i])的元组,其中i是循环中的步骤。然后,np.argmax为您提供具有最大值的元素的元组索引。