加快numpy过滤

时间:2016-08-24 14:21:57

标签: python numpy scipy png

我正在制作音乐识别程序,作为其中的一部分,我需要从png(2200x1700像素)中找到numpy数组的最大连通区域。我目前的解决方案如下。

labels, nlabels = ndimage.label(blobs)
cutoff = len(blobs)*len(blobs[0]) / nlabels
blobs_found = 0
x = []
t1 = time()
for n in range(1, nlabels+1):
    squares = np.where(labels==n)
    if len(squares[0]) < cutoff:
        blobs[squares] = 0
    else:
        blobs_found += 1
        blobs[squares] = blobs_found
        x.append(squares - np.amin(squares, axis=0, keepdims=True))
nlabels = blobs_found
print(time() - t1)

这样可行,但运行需要约6.5秒。有没有办法可以从这段代码中删除循环(或以其他方式加速)?

2 个答案:

答案 0 :(得分:2)

您可以使用以下符号获取每个标记区域的大小(以像素为单位):

unique_labels = numpy.unique(labels)
label_sizes = scipy.ndimage.measurement.sum(numpy.ones_like(blobs), labels, unique_labels)

最大的将是:

unique_labels[label_size == numpy.max(label_size)]

答案 1 :(得分:2)

最快可能是使用numpy.bincount并从那里开始工作。类似的东西:

labels, nlabels = ndimage.label(blobs)
cutoff = len(blobs)*len(blobs[0]) / float(nlabels)

label_counts = np.bincount(labels)

# Re-label, taking the cutoff into account
cutoff_mask = (label_counts >= cutoff)
cutoff_mask[0] = False
label_mapping = np.zeros_like(label_counts)
label_mapping[cutoff_mask] = np.arange(cutoff_mask.sum()) + 1

# Create an image-array with the updated labels
blobs = label_mapping[labels].astype(blobs.dtype)

这可以针对速度进行优化,但我的目标是可读性。