我有以下numpy数组:
import numpy as np
a = np.asarray([887, 895, 903, 911, 920, 928, 936,
944, 952, 961, 969, 977, 985, 905,
914, 924, 934, 944, 954, 965, 975, 986, 996, 1007])
值不断增加,直至985
,然后降至905
。从这里开始,价值再次开始增加。
我需要一个检测此丢弃的函数,并删除所有大于它所删除值的数组元素,使剩余值仍然单调递增(要删除的值为粗体):
[887,895,903, 911,920,928,936,
944,952,961,969,977,985, 905,
914,924,934,944,954,965,975,986,996,1007]
期望的结果如下:
[887, 895, 903, 905,
914, 924, 934, 944, 954, 965, 975, 986, 996, 1007]
我怎么能这样做?
答案 0 :(得分:1)
可能有更优雅的解决方案,但这似乎有效:
# get index where you observe the drop
ind_drop = np.where(np.diff(a) < 0)[0] + 1 # or np.argmin(np.diff(a)) + 1
# get index from start of the range which should be deleted
ind_low = np.argmin(a < a[ind_drop])
# delete the requested range
a_new = np.delete(a, np.arange(ind_low, ind_drop, 1))
产生
array([ 887., 895., 903., 905., 914., 924., 934., 944.,
954., 965., 975., 986., 996., 1007.])
一些解释:
必须找到应该切割数组的索引。第二个指数ind_drop
在那里我们观察到下降,即两个元素之间的差异首次变为负值。
np.diff(a)
array([ 8, 8, 8, 9, 8, 8, 8, 8, 9, 8, 8, 8, -80,
9, 10, 10, 10, 10, 11, 10, 11, 10, 11])
我们可以使用布尔数组
来获取此索引np.diff(a) < 0
array([False, False, False, False, False, False, False, False, False,
False, False, False, True, False, False, False, False, False,
False, False, False, False, False], dtype=bool)
并应用np.where
。
np.where(np.diff(a) < 0)
(array([12]),)
或者,您也可以使用:
np.argmin(np.diff(a)) + 1
第一个索引 - 我们开始削减的位置 - 我们根据与ind_drop
对应的值得出
a[ind_drop]
array([905])
所以我们需要找到第一个元素的索引,该索引大于这个值,我们可以通过将np.argmin
应用于布尔数组来实现:
a < a[ind_drop]
array([ True, True, True, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False], dtype=bool)
np.argmin
返回数组中最小值的(第一个)索引;它适用于布尔数组,True
为1,False
为0:
np.argmin(a < a[ind_drop])
3
现在我们有两个索引,我们可以使用np.delete
删除这些索引之间的所有元素:
np.arange(ind_low, ind_drop, 1)
array([ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
产生所需的输出。
答案 1 :(得分:1)
有点健壮,但应该这样做:
a = [887, 895, 903, 911, 920, 928, 936,
944, 952, 961, 969, 977, 985, 905,
914, 924, 934, 944, 954, 965, 975,
986, 996, 1007]
stop = [j for i, j in zip(a, a[1:]) if j < i][0]
drop = False
for i, e in enumerate(a):
if e > stop:
drop = i
break
if drop:
print(a[:drop] + a[a.index(stop):])
#[887, 895, 903, 905, 914, 924, 934, 944, 954, 965, 975, 986, 996, 1007]