如何从numpy数组中删除某个范围?

时间:2018-01-26 12:38:57

标签: python arrays numpy

我有以下numpy数组:

import numpy as np

a = np.asarray([887, 895, 903, 911, 920, 928, 936,  
944, 952, 961, 969, 977, 985, 905, 
914, 924, 934, 944, 954, 965, 975, 986, 996, 1007])

值不断增加,直至985,然后降至905。从这里开始,价值再次开始增加。

我需要一个检测此丢弃的函数,并删除所有大于它所删除值的数组元素,使剩余值仍然单调递增(要删除的值为粗体):

[887,895,903, 911,920,928,936,
944,952,961,969,977,985,
905, 914,924,934,944,954,965,975,986,996,1007]

期望的结果如下:

[887, 895, 903, 905, 
914, 924, 934, 944, 954, 965, 975, 986, 996, 1007]

我怎么能这样做?

2 个答案:

答案 0 :(得分:1)

可能有更优雅的解决方案,但这似乎有效:

# get index where you observe the drop
ind_drop = np.where(np.diff(a) < 0)[0] + 1  # or np.argmin(np.diff(a)) + 1

# get index from start of the range which should be deleted
ind_low = np.argmin(a < a[ind_drop])

# delete the requested range
a_new = np.delete(a, np.arange(ind_low, ind_drop, 1))

产生

array([  887.,   895.,   903.,   905.,   914.,   924.,   934.,   944.,
         954.,   965.,   975.,   986.,   996.,  1007.])

一些解释:

必须找到应该切割数组的索引。第二个指数ind_drop在那里我们观察到下降,即两个元素之间的差异首次变为负值。

np.diff(a)
array([  8,   8,   8,   9,   8,   8,   8,   8,   9,   8,   8,   8, -80,
         9,  10,  10,  10,  10,  11,  10,  11,  10,  11])

我们可以使用布尔数组

来获取此索引
np.diff(a) < 0
array([False, False, False, False, False, False, False, False, False,
       False, False, False,  True, False, False, False, False, False,
       False, False, False, False, False], dtype=bool)

并应用np.where

np.where(np.diff(a) < 0)
(array([12]),)

或者,您也可以使用:

np.argmin(np.diff(a)) + 1

第一个索引 - 我们开始削减的位置 - 我们根据与ind_drop对应的值得出

a[ind_drop]
array([905])

所以我们需要找到第一个元素的索引,该索引大于这个值,我们可以通过将np.argmin应用于布尔数组来实现:

a < a[ind_drop]
array([ True,  True,  True, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False], dtype=bool)

np.argmin返回数组中最小值的(第一个)索引;它适用于布尔数组,True为1,False为0:

np.argmin(a < a[ind_drop])
3

现在我们有两个索引,我们可以使用np.delete删除这些索引之间的所有元素:

np.arange(ind_low, ind_drop, 1)
array([ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

产生所需的输出。

答案 1 :(得分:1)

有点健壮,但应该这样做:

a = [887, 895, 903, 911, 920, 928, 936,  
     944, 952, 961, 969, 977, 985, 905, 
     914, 924, 934, 944, 954, 965, 975,
     986, 996, 1007]

stop = [j for i, j in zip(a, a[1:]) if j < i][0]
drop = False

for i, e in enumerate(a):
    if e > stop:
        drop = i
        break

if drop:
    print(a[:drop] + a[a.index(stop):])

#[887, 895, 903, 905, 914, 924, 934, 944, 954, 965, 975, 986, 996, 1007]