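I have the following code, which replaces every value of an image with the median (or any other function) of its segment, based on the masks in a label image created during a segmentation step. It feels like the for loop could be vectorized. What is the best way to do that?
I looked into building a separate index array for each label, but in the end I could not see how that would help.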
import numpy as np
from skimage.segmentation import slic
from skimage import data, io
def create_segment_image(original_image, labels_image):
    # Replace every pixel with the per-channel median of its segment.
    segment_image = np.zeros(original_image.shape, original_image.dtype)
    for label in np.unique(labels_image):
        segment_image[labels_image==label] = np.median(original_image[labels_image==label], axis=0)
    return segment_image

if __name__ == '__main__':
    original_image = data.astronaut()
    # Note: newer scikit-image releases may require max_num_iter instead of max_iter.
    labels_image = slic(original_image, n_segments=1000, max_iter=10, compactness=7, sigma=1, convert2lab=True, enforce_connectivity=True, min_size_factor=0.1, max_size_factor=3, slic_zero=False)
    segment_image = create_segment_image(original_image, labels_image)
    # io.imsave('images/segment_image.png', segment_image)
Answer 0 (score: 1)
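I don't know of a way to vectorize the innermost loop. Each call to median operates on a different number of elements, which makes it hard to pack all the calls into a single array operation.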
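To make the contrast concrete: a reduction that can be written as a sum, such as the mean, does vectorize over labels, because np.bincount does not care how many pixels each label has. The median has no such form. The following is a minimal sketch on made-up toy arrays (the array values and variable names are illustrative only), assuming a single-channel image and non-negative integer labels.

```python
import numpy as np

# Toy data (hypothetical): a 2x3 single-channel image and its label map.
image = np.array([[1.0, 2.0, 9.0],
                  [3.0, 4.0, 5.0]])
labels = np.array([[0, 0, 1],
                   [0, 1, 1]])

flat_labels = labels.ravel()

# Sum of pixel values per label and number of pixels per label.
sums = np.bincount(flat_labels, weights=image.ravel())
counts = np.bincount(flat_labels)

# Per-label mean, then broadcast back to image shape by indexing with the labels.
means = sums / counts
mean_image = means[labels]
```

This only works because the mean decomposes into a sum and a count; it is not a replacement for the median-based function in the question.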
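On the other hand, there is some fairly low-hanging fruit in how the elements are selected by label. The original function finds the indices for each label twice; computing the index array only once shaves roughly 25 to 30% off the runtime (2.15 s down to 1.48 s in the benchmark below).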
def create_segment_image_2(original_image, labels_image):
    segment_image = np.zeros(original_image.shape, original_image.dtype)
    for label in np.unique(labels_image):
        # find the pixels of this label once and reuse the index arrays
        inds = np.where(labels_image == label)
        segment_image[inds] = np.median(original_image[inds], axis=0)
    return segment_image
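A much bigger improvement comes from sorting the array indices by label and then using that ordering to pick out the image elements for each median. Replacing many searches with a single sort makes the function roughly 18 times faster than the original in the benchmark below.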
def create_segment_image_3(original_image, labels_image):
    segment_image = np.zeros(original_image.shape, original_image.dtype)
    # sort the flat pixel indices by their labels
    labelinds = np.argsort(labels_image, axis=None)
    labels = np.unique(labels_image)
    # use searchsorted to find where each label's block ends in the sorted order
    rights = np.searchsorted(labels_image.flatten(), labels, side='right', sorter=labelinds)
    left = 0
    for right in rights:
        # choose our block of the image array
        inds = labelinds[left:right]
        # convert back to a two dimensional index; this must be a tuple,
        # because recent NumPy treats a list of arrays as a single index array
        inds = (inds // original_image.shape[1], inds % original_image.shape[1])
        segment_image[inds] = np.median(original_image[inds], axis=0)
        # update our boundaries
        left = right
    return segment_image
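The grouping step is easier to follow on a tiny example. The sketch below uses made-up 2x3 arrays (the values and the names `values`, `order`, `lefts` and `groups` are illustrative, not part of the functions above) to show how one argsort plus one searchsorted yields a contiguous block of indices per label.

```python
import numpy as np

# Toy label map and matching single-channel "image" (hypothetical values).
labels_image = np.array([[2, 1, 2],
                         [1, 3, 2]])
values = np.array([[10, 20, 30],
                   [40, 50, 60]])

# Flat indices sorted by label; a stable sort keeps pixels of one label in scan order.
order = np.argsort(labels_image, axis=None, kind="stable")
labels = np.unique(labels_image)

# Right edge of each label's block within the sorted order.
rights = np.searchsorted(labels_image.ravel(), labels, side="right", sorter=order)
lefts = np.concatenate(([0], rights[:-1]))

# Collect the pixel values that fall into each block.
groups = {
    int(lab): values.ravel()[order[l:r]].tolist()
    for lab, l, r in zip(labels, lefts, rights)
}
print(groups)
```

Each slice `order[l:r]` holds exactly the flat indices of one label, so the per-label work reduces to a slice rather than a full comparison against the label image.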
Benchmarks in IPython:
In [54]: %timeit create_segment_image(original_image, labels_image)
2.15 s ± 29.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [55]: %timeit create_segment_image_2(original_image, labels_image)
1.48 s ± 4.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [56]: %timeit create_segment_image_3(original_image, labels_image)
121 ms ± 561 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Confirming that the new solutions give the same results as the original:
In [57]: np.all(create_segment_image_2(original_image, labels_image) == create_segment_image(original_image, labels_image))
Out[57]: True
In [58]: np.all(create_segment_image_3(original_image, labels_image) == create_segment_image(original_image, labels_image))
Out[58]: True
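For completeness, the median itself can also be obtained without a Python loop over labels, by sorting the pixel values within each label and picking the middle element (or the two middle elements) of every block. This is a sketch of a different technique from the functions timed above, not something that was benchmarked here. The function name `segment_median_sorted` is hypothetical, and the code assumes an (H, W, C) image; it still loops over the channels.

```python
import numpy as np

def segment_median_sorted(image, labels):
    """Per-label, per-channel median of an (H, W, C) image (hypothetical helper)."""
    h, w, c = image.shape
    flat_labels = labels.ravel()
    flat_pixels = image.reshape(-1, c)

    # Map labels to 0..K-1 and count the pixels in each segment.
    uniq, inverse, counts = np.unique(
        flat_labels, return_inverse=True, return_counts=True
    )
    starts = np.cumsum(counts) - counts

    # Positions of the lower and upper middle element inside each sorted block.
    lo = starts + (counts - 1) // 2
    hi = starts + counts // 2

    medians = np.empty((uniq.size, c), dtype=np.float64)
    for ch in range(c):  # loop over channels only, not over labels
        # Sort by label first, then by pixel value within each label.
        order = np.lexsort((flat_pixels[:, ch], inverse))
        sorted_vals = flat_pixels[order, ch].astype(np.float64)
        # Average the two middle values (identical when the count is odd).
        medians[:, ch] = 0.5 * (sorted_vals[lo] + sorted_vals[hi])

    # Broadcast the medians back to every pixel and cast to the input dtype.
    return medians[inverse].reshape(h, w, c).astype(image.dtype)
```

Whether this is faster than create_segment_image_3 depends on the image size and the number of segments, so it is worth timing on your own data before adopting it.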