在列表中查找numpy数组的索引

时间:2016-12-21 21:12:59

标签: python arrays list numpy

import numpy as np
foo = [1, "hello", np.array([[1,2,3]]) ]

我希望

foo.index( np.array([[1,2,3]]) ) 

返回

2

但我得到了

  

ValueError:具有多个元素的数组的真值   暧昧。使用a.any()或a.all()

什么比我目前的解决方案更好?这似乎效率低下。

def find_index_of_array(list, array):
    for i in range(len(list)):
        if np.all(list[i]==array):
            return i

find_index_of_array(foo, np.array([[1,2,3]]) )
# 2

5 个答案:

答案 0 :(得分:12)

这里出错的原因显然是因为numpy的ndarray会覆盖==来返回数组而不是布尔值。

AFAIK,这里没有简单的解决方案。只要中np.all(val == array)位有效,以下内容就可以正常工作。

next((i for i, val in enumerate(lst) if np.all(val == array)), -1)

该位是否有效取决于数组中的其他元素是什么,以及它们是否可以与numpy数组进行比较。

答案 1 :(得分:2)

为了提高性能,您可能只想处理输入列表中的NumPy数组。因此,我们可以在进入循环之前进行类型检查并索引到数组元素。

因此,实现将是 -

def find_index_of_array_v2(list1, array1):
    idx = np.nonzero([type(i).__module__ == np.__name__ for i in list1])[0]
    for i in idx:
        if np.all(list1[i]==array1):
            return i

答案 2 :(得分:2)

这个怎么样?

arr = np.array([[1,2,3]])
foo = np.array([1, 'hello', arr], dtype=np.object)

# if foo array is of heterogeneous elements (str, int, array)
[idx for idx, el in enumerate(foo) if type(el) == type(arr)]

# if foo array has only numpy arrays in it
[idx for idx, el in enumerate(foo) if np.array_equal(el, arr)]

<强>输出:

[2]

注意:即使foo是列表,这也会有效。我只是把它作为numpy数组放在这里。

答案 3 :(得分:2)

这里的问题(你可能已经知道但只是重复一遍)是list.index的工作方式:

for idx, item in enumerate(your_list):
    if item == wanted_item:
        return idx

if item == wanted_item是问题,因为它隐式地将item == wanted_item转换为布尔值。但numpy.ndarray(除非它是标量)会引发此ValueError

  

ValueError:具有多个元素的数组的真值是不明确的。使用a.any()或a.all()

解决方案1:适配器(瘦包装器)类

每当我需要使用像numpy.ndarray这样的python函数时,我通常会在list.index周围使用一个瘦包装器(适配器):

class ArrayWrapper(object):

    __slots__ = ["_array"]  # minimizes the memory footprint of the class.

    def __init__(self, array):
        self._array = array

    def __eq__(self, other_array):
        # array_equal also makes sure the shape is identical!
        # If you don't mind broadcasting you can also use
        # np.all(self._array == other_array)
        return np.array_equal(self._array, other_array)

    def __array__(self):
        # This makes sure that `np.asarray` works and quite fast.
        return self._array

    def __repr__(self):
        return repr(self._array)

这些瘦包装比使用一些enumerate循环或理解手动更昂贵,但您不必重新实现python函数。假设列表只包含numpy-arrays(否则你需要进行一些if ... else ...检查):

list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]

在此步骤之后,您可以使用此列表中的所有python函数:

>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1

这些包装器不再是numpy-arrays,但你有薄包装器,所以额外的列表非常小。因此,根据您的需要,您可以保留包装列表和原始列表,并选择执行操作,例如,您现在也可以list.count相同的数组:

>>> list_of_wrapped_arrays.count(np.ones((3)))
2

list.remove

>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]), 
 array([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]]), 
 array([ 1.,  1.,  1.])]

解决方案2:子类和ndarray.view

此方法使用numpy.array的显式子类。它的优点是您可以获得所有内置的数组功能,并且只修改请求的操作(__eq__):

class ArrayWrapper(np.ndarray):
    def __eq__(self, other_array):
        return np.array_equal(self, other_array)

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]

>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]

>>> view_list.index(np.array([2,2,2]))
1

您再次以这种方式获得大多数列表方法:list.removelist.count除了list.index

但是,如果某些操作隐式使用__eq__,则此方法可能会产生微妙的行为。您始终可以使用np.asarray.view(np.ndarray)重新解释为简单的numpy数组:

>>> view_list[1]
ArrayWrapper([ 2.,  2.,  2.])

>>> view_list[1].view(np.ndarray)
array([ 2.,  2.,  2.])

>>> np.asarray(view_list[1])
array([ 2.,  2.,  2.])

替代方法:覆盖__bool__(或__nonzero__ for python 2)

您可以覆盖__eq____bool__

,而不是在__nonzero__方法中解决问题。
class ArrayWrapper(np.ndarray):
    # This could also be done in the adapter solution.
    def __bool__(self):
        return bool(np.all(self))

    __nonzero__ = __bool__

这再次使list.index像预期一样工作:

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1

但这肯定会修改更多行为!例如:

>>> if ArrayWrapper([1,2,3]):
...     print('that was previously impossible!')
that was previously impossible!

答案 4 :(得分:0)

这应该做的工作:

[i for i,j in enumerate(foo) if j.__class__.__name__=='ndarray']
[2]