我有一个1和0的Numpy一维数组。例如
a = np.array([0,1,1,1,0,0,0,0,0,0,0,1,0,1,1,0,0,0,1,1,0,0])
如果连续0的长度小于阈值,我想将连续0替换为1,让2表示,并排除第一个和最后一个连续0。所以它会输出一个像这样的新数组
out: [0,1,1,1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,1,1,0,0]
如果阈值为4,则输出为
out: [0,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,0,0]
我所做的是计算每个细分市场'长度 我从this answer
得到了这个解决方案segLengs = np.diff(np.flatnonzero(np.concatenate(([True], a[1:]!= a[:-1], [True] ))))
out: [1,3,7,1,1,2,3,2,2]
然后找到小于阈值的段
gaps = np.where(segLengs <= threshold)[0]
gapsNeedPadding = gaps[gaps % 2 == 0]
然后循环gapsNeedPadding
数组
同样itertools.groupby
可以完成这项工作,但会有点慢
是否有更有效的解决方案?我更喜欢矢量化解决方案。速度是我所需要的。我已经有了一个缓慢的解决方案,通过数组循环
更新
在this question尝试了来自@divakar的解决方案,但是当阈值较大时,它似乎无法解决我的问题。
numpy_binary_closing
和binary_closing
有不同的输出。这两个函数都不会从边界+阈值
我是否在以下代码中犯了错误?
import numpy as np
from scipy.ndimage import binary_closing
def numpy_binary_closing(mask,threshold):
# Define kernel
K = np.ones(threshold)
# Perform dilation and threshold at 1
dil = np.convolve(mask, K, mode='same') >= 1
# Perform erosion on the dilated mask array and threshold at given threshold
dil_erd = np.convolve(dil, K, mode='same') >= threshold
return dil_erd
threshold = 4
mask = np.random.rand(100) > 0.5
print(mask.astype(int))
out1 = numpy_binary_closing(mask, threshold)
out2 = binary_closing(mask, structure=np.ones(threshold))
print(out1.astype(int))
print(out2.astype(int))
print(np.allclose(out1,out2))
Outout
[0 1 1 0 1 1 0 0 0 1 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 1 1 1 1 0 0 0 1 1 0 0 0 1 1 0 1 0 1 0 0 0 0 1 0 0 1 0 1 1 1 1 1 1 0 1 0 0 0 1 0 1 0 0 0 1 1 1 0 1 1 0 1 1 1 1 0 1 1 1 0 0 0 1 0 0 0 0 1 0 1 1 1]
[0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 0]
[0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 0]
False
答案 0 :(得分:1)
没有更好的想法:
for _ in range(threshold - 1):
a |= np.roll(a, 1)
(此代码不处理尾随零。)