Image segmentation with k-means in Python

Time: 2018-07-09 12:42:30

Tags: python image image-processing machine-learning k-means

I'm new to machine learning and I'm learning to use k-means for image segmentation, but I can't understand this code:

import os
import matplotlib.pyplot as plt
from matplotlib.image import imread
from sklearn.cluster import KMeans
image = imread(os.path.join("images","unsupervised_learning","ladybug.png"))
image.shape
X = image.reshape(-1, 3)
kmeans = KMeans(n_clusters=8, random_state=42).fit(X)
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
segmented_img = segmented_img.reshape(image.shape)
segmented_imgs = []
n_colors = (10, 8, 6, 4, 2)
for n_clusters in n_colors:
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)
    segmented_img = kmeans.cluster_centers_[kmeans.labels_]
    segmented_imgs.append(segmented_img.reshape(image.shape))
plt.figure(figsize=(10,5))
plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.subplot(231)
plt.imshow(image)
plt.title("Original image")
plt.axis('off')
for idx, n_clusters in enumerate(n_colors):
    plt.subplot(232 + idx)
    plt.imshow(segmented_imgs[idx])
    plt.title("{} colors".format(n_clusters))
    plt.axis('off')
plt.show()

Image used: (image not available)   Output figure: (image not available)

In particular, what does this line mean?

segmented_img = kmeans.cluster_centers_[kmeans.labels_]

1 answer:

Answer 0 (score: 0)

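I suggest you first read up on unsupervised learning, in particular clustering with K-means and its general concepts and applications, not just its use on images.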
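To see what K-means clustering does independently of images, here is a minimal from-scratch sketch of Lloyd's algorithm (the classic K-means procedure) in NumPy. It alternates two steps: assign every point to its nearest center, then move every center to the mean of its assigned points. This is an illustration only, not scikit-learn's implementation (KMeans, for example, uses k-means++ initialisation by default); the helper names nearest_center and kmeans_lloyd and the toy data are hypothetical.

```python
import numpy as np


def nearest_center(X, centers):
    # Distance from every point to every center, shape (n_points, k)
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    # Label of each point = index of its closest center
    return dists.argmin(axis=1)


def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Hypothetical minimal K-means (Lloyd's algorithm); returns (centers, labels)."""
    rng = np.random.default_rng(seed)
    # Initialise the k centers with k distinct points drawn at random from X
    centers = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        labels = nearest_center(X, centers)  # assignment step
        # Update step: move each center to the mean of its points (keep it if its cluster is empty)
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break  # converged: the centers no longer move
        centers = new_centers
    # Labels consistent with the returned centers
    return centers, nearest_center(X, centers)


# Toy 2-D data: two well-separated groups of three points each
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
centers, labels = kmeans_lloyd(X, k=2)
print(labels)   # the first three points share one label, the last three the other
print(centers)  # approximately [1/3, 1/3] and [31/3, 31/3], in some order
```

Conceptually, the image code below runs the same kind of loop, with each pixel treated as one 3-D point (R, G, B) instead of a 2-D point.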

I will comment each line of this code to explain what is going on.

import os
import matplotlib.pyplot as plt
from matplotlib.image import imread #import the image reader
from sklearn.cluster import KMeans
image = imread(os.path.join("images","unsupervised_learning","ladybug.png")) #Read the image as a NumPy array
image.shape #Shape of the image: (height, width, channels); this code assumes 3 channels (RGB)
X = image.reshape(-1, 3) #Flatten to (height*width, 3): one row per pixel, holding its R, G, B values
kmeans = KMeans(n_clusters=8, random_state=42).fit(X) #Fit K-means with 8 clusters on the pixel colours
segmented_img = kmeans.cluster_centers_[kmeans.labels_] #Replace each pixel by the centre (mean colour) of its cluster
segmented_img = segmented_img.reshape(image.shape) #Reshape back to the original (height, width, 3) image shape
segmented_imgs = []
n_colors = (10, 8, 6, 4, 2)
for n_clusters in n_colors:
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X) #Fit K-means once per number of clusters (10, 8, 6, ...)
    segmented_img = kmeans.cluster_centers_[kmeans.labels_] #Same as above: each pixel takes its cluster centre's colour
    segmented_imgs.append(segmented_img.reshape(image.shape))
plt.figure(figsize=(10,5)) #Plotting code
plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.subplot(231)
plt.imshow(image)
plt.title("Original image")
plt.axis('off')
for idx, n_clusters in enumerate(n_colors):
    plt.subplot(232 + idx)
    plt.imshow(segmented_imgs[idx])
    plt.title("{} colors".format(n_clusters))
    plt.axis('off')
plt.show()
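A common point of confusion is what image.reshape(-1, 3) does. The following minimal NumPy sketch uses a hypothetical 2 x 3 synthetic image (no file needed) to show that the reshape does not put the colour channel "first". Instead, it flattens the pixels into rows and keeps the three channels as columns, and reshaping with image.shape undoes it exactly.

```python
import numpy as np

# Hypothetical tiny "image": height 2, width 3, 3 colour channels, values in [0, 1]
image = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3) / 17.0

# Flatten the spatial dimensions: one row per pixel, one column per channel
X = image.reshape(-1, 3)
print(X.shape)  # (6, 3)

# Row i of X is the RGB triple of pixel (i // width, i % width)
height, width, _ = image.shape
assert np.array_equal(X[4], image[4 // width, 4 % width])

# Reshaping back with image.shape restores the original layout exactly
restored = X.reshape(image.shape)
assert np.array_equal(restored, image)
```

This is why K-means can be fitted on X directly: each row is one sample with three features (R, G, B), which is the (n_samples, n_features) layout scikit-learn expects.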

In this line of code,

segmented_img = kmeans.cluster_centers_[kmeans.labels_]

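As the names suggest, cluster_centers_ is an array attribute holding the coordinates of the cluster centres (here one RGB colour per cluster, shape (n_clusters, 3)), and labels_ is an attribute holding the cluster label of each point (here each pixel, shape (n_pixels,)). Indexing cluster_centers_ with labels_ (NumPy integer-array indexing) picks, for every point, the row of the centre of the cluster that point belongs to. So segmented_img has shape (n_pixels, 3) and contains, for each point, the coordinates of its cluster centre, i.e. every pixel is recoloured with the mean colour of its cluster.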
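The indexing itself can be checked without fitting anything. Below is a minimal NumPy sketch with made-up values standing in for cluster_centers_ and labels_ (two hypothetical clusters, five hypothetical pixels). It shows that centers[labels] returns one centre row per label, the same as an explicit loop.

```python
import numpy as np

# Hypothetical stand-ins for kmeans.cluster_centers_ and kmeans.labels_ (k = 2, 5 pixels)
cluster_centers = np.array([[1.0, 0.0, 0.0],   # cluster 0: red
                            [0.0, 0.0, 1.0]])  # cluster 1: blue
labels = np.array([0, 1, 1, 0, 1])             # cluster index of each pixel

# Integer-array indexing picks one row of cluster_centers per label
segmented = cluster_centers[labels]
print(segmented.shape)  # (5, 3)
print(segmented[1])     # [0. 0. 1.]

# Equivalent explicit loop over the pixels
loop_version = np.array([cluster_centers[lbl] for lbl in labels])
assert np.array_equal(segmented, loop_version)
```

With a real image the only differences are the sizes: labels_ has one entry per pixel, and the result is then reshaped back to image.shape.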