我有一个相当大的3维numpy(2000,2500,32)数组,我需要操作。有些行很糟糕,所以我需要删除几行。 为了检测哪一行是坏的"我使用以下功能
def badDetect(x):
for i in xrange(10,19):
ptp = np.ptp(x[i*100:(i+1)*100])
if ptp < 0.01:
return True
return False
,其标记为2000的任何序列,其具有100个值的范围,峰值到峰值小于0.01。 在这种情况下,我想删除2000个值的序列(可以从[:,x,y]中选择numpy) Numpy删除似乎是接受索引,但仅适用于二维数组。
答案 0 :(得分:0)
你肯定要重塑你的输入数组,因为切掉&#34;行&#34;从3D立方体中留下一个无法正确处理的结构。
由于我们没有您的数据,我将首先使用不同的示例来解释这种可能的解决方案的工作原理:
>>> import numpy as np
>>> from numpy.lib.stride_tricks import as_strided
>>>
>>> threshold = 18
>>> a = np.arange(5*3*2).reshape(5,3,2) # your dataset of 2000x2500x32
>>> # Taint the data:
... a[0,0,0] = 5
>>> a[a==22]=20
>>> print(a)
[[[ 5 1]
[ 2 3]
[ 4 5]]
[[ 6 7]
[ 8 9]
[10 11]]
[[12 13]
[14 15]
[16 17]]
[[18 19]
[20 21]
[20 23]]
[[24 25]
[26 27]
[28 29]]]
>>> a2 = a.reshape(-1, np.prod(a.shape[1:]))
>>> print(a2) # Will prove to be much easier to work with!
[[ 5 1 2 3 4 5]
[ 6 7 8 9 10 11]
[12 13 14 15 16 17]
[18 19 20 21 20 23]
[24 25 26 27 28 29]]
正如您所看到的,从上面的表示中,现在已经变得更加清晰,您想要计算峰值到峰值的窗口。如果你要删除&#34;行,那么你需要这个表格。 (现在它们已被转换为列)来自这个数据结构,这是你无法在3个方面做的事情!
>>> isize = a.itemsize # More generic, in case you have another dtype
>>> slice_size = 4 # How big each continuous slice is over which the Peak2Peak value is calculated
>>> slices = as_strided(a2,
... shape=(a2.shape[0] + 1 - slice_size, slice_size, a2.shape[1]),
... strides=(isize*a2.shape[1], isize*a2.shape[1], isize))
>>> print(slices)
[[[ 5 1 2 3 4 5]
[ 6 7 8 9 10 11]
[12 13 14 15 16 17]
[18 19 20 21 20 23]]
[[ 6 7 8 9 10 11]
[12 13 14 15 16 17]
[18 19 20 21 20 23]
[24 25 26 27 28 29]]]
因此,我采用了4个元素的窗口大小:如果这4个元素切片中的任何一个中的峰峰值(每个数据集,因此每列)小于某个阈值,我想排除它。这可以这样做:
>>> mask = np.all(slices.ptp(axis=1) >= threshold, axis=0) # These are the ones that are of interest
>>> print(a2[:,mask])
[[ 1 2 3 5]
[ 7 8 9 11]
[13 14 15 17]
[19 20 21 23]
[25 26 27 29]]
您现在可以清楚地看到已删除的污染数据。但请记住,你不能简单地从3D阵列中删除这些数据(但你可以将其掩盖掉)。
显然,您必须在用例中将threshold
设置为.01
,并将slice_size
设置为100
。
请注意,虽然as_strided
表单的内存效率极高,但计算此数组的峰峰值并存储该结果确实需要大量内存:1901x(2500x32)案例场景,所以当你不忽略前1000个切片。在您的情况下,您只对1000:1900
中的切片感兴趣,您必须将其添加到代码中,如下所示:
mask = np.all(slices[1000:1900,:,:].ptp(axis=1) >= threshold, axis=0)
这会减少将此掩码存储到&#34;仅限&#34; 900x(2500x32)值(无论您使用何种数据类型)。