使用Numpy布尔数组的索引Python列表

时间:2016-11-24 20:08:52

标签: python numpy

有没有办法使用numpy布尔数组索引像x = ['a','b','c']这样的python列表?我目前收到以下错误:TypeError: only integer arrays with one element can be converted to an index

4 个答案:

答案 0 :(得分:3)

通过[]索引秘密调用对象的__getitem__方法。对于用纯Python实现的对象,您可以使用任何适合您需要的函数覆盖此方法。但是列表是用C实现的,因此不允许替换list.__getitem__。因此,没有直接的方法可以按照您的要求进行操作。

然而,您可以从列表中创建NumPy数组,然后对其执行NumPy样式的布尔索引:

import numpy as np

x = ['a', 'b', 'c']

mask = np.array([True, False, True])
x_arr = np.asarray(x, dtype=object)
output = x_arr[mask]  # Get items
x_arr[mask] = ['new', 'values']  # Set items

不幸的是,np.asarray无法简单地查看您的列表,因此只需复制列表。这意味着在为x元素分配新值时,原始x_arr不会更改。

如果你真的想要在列表上使用NumPy布尔索引的全部功能,你必须编写一个从头开始执行此操作的函数,并且你将无法使用[]索引语法。

答案 1 :(得分:3)

In [304]: ['a','b','c'][[2,1,0]]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-304-c04b1f0621a3> in <module>()
----> 1 ['a','b','c'][[2,1,0]]

TypeError: list indices must be integers or slices, not list

列表理解路线

In [306]: [i for i,j in zip(['a','b','c'],[True, False, True]) if j]
Out[306]: ['a', 'c']

阵列路线

In [308]: np.array(['a','b','c'])[np.array([True, False, True])]
Out[308]: 
array(['a', 'c'], 
      dtype='<U1')

返回列表:

In [309]: np.array(['a','b','c'])[np.array([True, False, True])].tolist()
Out[309]: ['a', 'c']

但是如果列表包含对象,请注意,而不是数字或字符串。这可能不会保留链接。

operator模块具有itemgetter功能

In [321]: operator.itemgetter(*[2,0,1])(list('abc'))
Out[321]: ('c', 'a', 'b')

但是在封面下它只是一个像迭代器这样的列表理解。而且我不会随便看到一个布尔版本。

答案 2 :(得分:3)

map(x.__getitem__,np.where(mask)[0])

或者如果你想要列表理解

[x[i] for i in np.where(mask)[0]]

这使您不必遍历整个列表,尤其是在mask稀疏的情况下。

答案 3 :(得分:1)

你需要它作为一个列表吗?由于您想要使用numpy数组的索引行为,因此如果您实际使用numpy数组,那么读取您的代码的其他人会更有意义。

也许尝试使用dtype =&#39; a&#39;?例如,在下面的代码中,

x = sp.array(['a', 'b', 'c'], dtype='a')
print(x)
print(x=='c')
print(x[x=='c']).

这将返回以下数组,

['a' 'b' 'c']
[False False  True]
['c'].

分配也会像你期望的那样工作,

x[x=='c'] = 'z'
print(x).  

这将返回修改后的数组

['a' 'b' 'z'].

唯一关心的是数组的元素不能长于分配的长度。在这里,它被指定为dtype =&#39; a&#39;。你可以使用dtype =&#39; a5&#39;或者dtype =&#39; aN&#39;你想要的任何长度。数组的所有元素必须是短于最大长度的字符串。

如果你传递一个太长的字符串,它会切断结尾,如下例所示,dtype设置为&#39; a2&#39;:

x = sp.array(['abc', 'bcd', 'cde'], dtype='a2')
print(x), 

给出,

['ab' 'bc' 'cd'].