有没有办法使用numpy布尔数组索引像x = ['a','b','c']
这样的python列表?我目前收到以下错误:TypeError: only integer arrays with one element can be converted to an index
答案 0 :(得分:3)
通过[]
索引秘密调用对象的__getitem__
方法。对于用纯Python实现的对象,您可以使用任何适合您需要的函数覆盖此方法。但是列表是用C实现的,因此不允许替换list.__getitem__
。因此,没有直接的方法可以按照您的要求进行操作。
然而,您可以从列表中创建NumPy数组,然后对其执行NumPy样式的布尔索引:
import numpy as np
x = ['a', 'b', 'c']
mask = np.array([True, False, True])
x_arr = np.asarray(x, dtype=object)
output = x_arr[mask] # Get items
x_arr[mask] = ['new', 'values'] # Set items
不幸的是,np.asarray
无法简单地查看您的列表,因此只需复制列表。这意味着在为x
元素分配新值时,原始x_arr
不会更改。
如果你真的想要在列表上使用NumPy布尔索引的全部功能,你必须编写一个从头开始执行此操作的函数,并且你将无法使用[]
索引语法。
答案 1 :(得分:3)
In [304]: ['a','b','c'][[2,1,0]]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-304-c04b1f0621a3> in <module>()
----> 1 ['a','b','c'][[2,1,0]]
TypeError: list indices must be integers or slices, not list
列表理解路线
In [306]: [i for i,j in zip(['a','b','c'],[True, False, True]) if j]
Out[306]: ['a', 'c']
阵列路线
In [308]: np.array(['a','b','c'])[np.array([True, False, True])]
Out[308]:
array(['a', 'c'],
dtype='<U1')
返回列表:
In [309]: np.array(['a','b','c'])[np.array([True, False, True])].tolist()
Out[309]: ['a', 'c']
但是如果列表包含对象,请注意,而不是数字或字符串。这可能不会保留链接。
operator
模块具有itemgetter
功能
In [321]: operator.itemgetter(*[2,0,1])(list('abc'))
Out[321]: ('c', 'a', 'b')
但是在封面下它只是一个像迭代器这样的列表理解。而且我不会随便看到一个布尔版本。
答案 2 :(得分:3)
map(x.__getitem__,np.where(mask)[0])
或者如果你想要列表理解
[x[i] for i in np.where(mask)[0]]
这使您不必遍历整个列表,尤其是在mask
稀疏的情况下。
答案 3 :(得分:1)
你需要它作为一个列表吗?由于您想要使用numpy数组的索引行为,因此如果您实际使用numpy数组,那么读取您的代码的其他人会更有意义。
也许尝试使用dtype =&#39; a&#39;?例如,在下面的代码中,
x = sp.array(['a', 'b', 'c'], dtype='a')
print(x)
print(x=='c')
print(x[x=='c']).
这将返回以下数组,
['a' 'b' 'c']
[False False True]
['c'].
分配也会像你期望的那样工作,
x[x=='c'] = 'z'
print(x).
这将返回修改后的数组
['a' 'b' 'z'].
唯一关心的是数组的元素不能长于分配的长度。在这里,它被指定为dtype =&#39; a&#39;。你可以使用dtype =&#39; a5&#39;或者dtype =&#39; aN&#39;你想要的任何长度。数组的所有元素必须是短于最大长度的字符串。
如果你传递一个太长的字符串,它会切断结尾,如下例所示,dtype设置为&#39; a2&#39;:
x = sp.array(['abc', 'bcd', 'cde'], dtype='a2')
print(x),
给出,
['ab' 'bc' 'cd'].