我有以下数组:
a = np.array([6,5,4,3,4,5,6])
现在我要获取所有大于4但索引值也大于2的元素。 我发现的方法如下:
a[2:][a[2:]>4]
是否有更好或更具可读性的方法来完成此任务?
更新: 这是简化版本。实际上,索引是通过对以下几个变量进行算术运算来完成的:
a[len(trainPredict)+(look_back*2)+1:][a[len(trainPredict)+(look_back*2)+1:]>4]
trainPredict
由一个numpy数组组成,look_back
是一个整数。
我想看看是否有确定的方法或其他人如何做到。
答案 0 :(得分:3)
如果您担心切片的复杂性和/或条件的数量,可以随时将它们分开:
a = np.array([6,5,4,3,4,5,6])
a_slice = a[2:]
cond_1 = a_slice > 4
res = a_slice[cond_1]
您的示例是否非常简化?对于更复杂的操作,可能会有更好的解决方案。
答案 1 :(得分:1)
@AlexanderCécile's answer不仅比您发布的那条衬线更清晰,而且还消除了临时数组的冗余计算。尽管如此,它似乎并没有比您原来的方法快。
下面的时间都是在预先设置为
的情况下运行的import numpy as np
np.random.seed(0xDEADBEEF)
a = np.random.randint(8, size=N)
N
从1e3到1e8的系数是10。我尝试了四种代码变体:
result = a[2:][a[2:] > 4]
s = a[2:]; result = s[s > 4]
result = a[np.flatnonzero(a[2:]) + 2]
result = a[(a > 4) & (np.arange(a.size) >= 2)]
在所有情况下,计时都是通过在命令行上运行来获取的
python -m timeit -s 'import numpy as np; np.random.seed(0xDEADBEEF); a = np.random.randint(8, size=N)' '<X>'
在这里,N
是3到8之间的10的幂,而<X>
是上面的表达式之一。时间如下:
方法#1和#2实际上是无法区分的。令人惊讶的是,在〜5e3和〜1e6元素之间的范围内,方法3似乎略有增加,但明显更快。我通常不会期望花哨的索引。方法4当然是最慢的。
以下是数据,出于完整性考虑:
CodePope AlexanderCécile MadPhysicist1 MadPhysicist2
1000 3.77e-06 3.69e-06 5.48e-06 6.52e-06
10000 4.6e-05 4.59e-05 3.97e-05 5.93e-05
100000 0.000484 0.000483 0.0004 0.000592
1000000 0.00513 0.00515 0.00503 0.00675
10000000 0.0529 0.0525 0.0617 0.102
100000000 0.657 0.658 0.782 1.09