从numpy数组

时间:2017-03-10 13:19:57

标签: numpy

问题

我有一个1维的numpy数组,主要用零填充,但也包含一些非零值组。

>> import numpy as np
>> a = np.zeros(10)
>> a[2:4] = 2
>> a[6:9] = 3
>> print a
[ 0.  0.  2.  2.  0.  0.  3.  3.  3.  0.]

我想获得仅包含最后一个非零组的数组。换句话说,除了最后一个非零组之外的所有组都应该用零替换。 (这些组可能只有1个元素长)。像这样:

[ 0.  0.  0.  0.  0.  0.  3.  3.  3.  0.]

非强健解决方案

这似乎可以解决问题。反转数组并找到元素之间的变化为负的第一个索引。然后用零替换所有后续元素。然后翻回来。这有点啰嗦:

>> b = a[::-1]
>> b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
>> c = b[::-1]
>> print c
[ 0.  0.  0.  0.  0.  0.  3.  3.  3.  0.]

针对特定案例失败

但是,它不健壮并且在以下情况下失败(因为where命令返回一个空的索引列表):

>> a = np.zeros(10)
>> a[0:4] = 2
>> print a
[ 2.  2.  2.  2.  0.  0.  0.  0.  0.  0.]
>> b = a[::-1]
>> b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
>> c = b[::-1]
>> print c

Traceback (most recent call last):

  File "<ipython-input-81-8cba57558ba8>", line 1, in <module>
    runfile('C:/Users/name/test1.py', wdir='C:/Users/name')

  File "C:\ProgramData\Anaconda2\lib\site-packages\spyder\utils\site\sitecustomize.py", line 866, in runfile
    execfile(filename, namespace)

  File "C:\ProgramData\Anaconda2\lib\site-packages\spyder\utils\site\sitecustomize.py", line 87, in execfile
    exec(compile(scripttext, filename, 'exec'), glob, loc)

  File "C:/Users/name/test1.py", line 21, in <module>
    b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0

IndexError: index 0 is out of bounds for axis 0 with size 0

修正

所以我需要引入一个if子句:

>> b = a[::-1]
>> if len(np.where(np.ediff1d(b) < 0)[0]) > 0:
>>     b[np.where(np.ediff1d(b) < 0)[0][0] + 1:] = 0
>> c = b[::-1]
>> print c
[ 2.  2.  2.  2.  0.  0.  0.  0.  0.  0.]

有更优雅的方式吗?

更新 继Divakar的优秀答案和mtrw的问题之后,我想扩展规范。如果输入数组的非零值为负值,则对于在分组内发生变化的非零数字组,该方法也应该有效。

e.g。 np.array([1, 0, 0, 4, 5, 4, 5, 0, 0])

这意味着我们检查元素之间的正面或负面差异的方法,以便找到组边界,这样做不会很好。

1 个答案:

答案 0 :(得分:2)

方法#1

既然我们追求的是优雅,那就让我们自己的吧 -

a[:(a[1:] > a[:-1]).cumsum().argmax()] = 0

示例运行 -

In [605]: a
Out[605]: array([ 0.,  0.,  2.,  2.,  0.,  0.,  3.,  3.,  3.,  0.])

In [606]: a[:(a[1:] > a[:-1]).cumsum().argmax()] = 0

In [607]: a
Out[607]: array([ 0.,  0.,  0.,  0.,  0.,  0.,  3.,  3.,  3.,  0.])

方法#2

以上方法假设最后一个组编号大于0&#39; s。如果情况并非如此,并且对于非零组可能具有不同数字的情况,请让提供一行以获得通用解决方案 -

mask = a != 0
a[:(mask[1:] > mask[:-1]).cumsum().argmax()] = 0

示例运行 -

In [667]: a
Out[667]: array([-1,  0,  0, -4, -5,  4, -5,  0,  0])

In [668]: mask = a != 0

In [669]: a[:(mask[1:] > mask[:-1]).cumsum().argmax()] = 0

In [670]: a
Out[670]: array([ 0,  0,  0, -4, -5,  4, -5,  0,  0])