I'm new to machine learning. I'm learning how to use k-means for image segmentation, but I can't understand this code:
import os
import matplotlib.pyplot as plt
from matplotlib.image import imread
from sklearn.cluster import KMeans
image = imread(os.path.join("images","unsupervised_learning","ladybug.png"))
image.shape
X = image.reshape(-1, 3)
kmeans = KMeans(n_clusters=8, random_state=42).fit(X)
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
segmented_img = segmented_img.reshape(image.shape)
segmented_imgs = []
n_colors = (10, 8, 6, 4, 2)
for n_clusters in n_colors:
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)
    segmented_img = kmeans.cluster_centers_[kmeans.labels_]
    segmented_imgs.append(segmented_img.reshape(image.shape))
plt.figure(figsize=(10,5))
plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.subplot(231)
plt.imshow(image)
plt.title("Original image")
plt.axis('off')
for idx, n_clusters in enumerate(n_colors):
    plt.subplot(232 + idx)
    plt.imshow(segmented_imgs[idx])
    plt.title("{} colors".format(n_clusters))
    plt.axis('off')
plt.show()
In particular, what does this line mean?
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
Answer (score: 0)
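I suggest you first read up on unsupervised learning, in particular clustering with K-means and its general concepts and applications, not just its use on images.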
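To make the general idea concrete, here is a minimal from-scratch sketch of the classic Lloyd iteration for K-means in NumPy, run on a few made-up 2-D points. It is a simplification and not scikit-learn's implementation (`KMeans` uses k-means++ initialization by default and has more options); the names `assign_labels` and `kmeans_lloyd` and the toy data are hypothetical. The two arrays it returns play the same roles as `KMeans`'s `cluster_centers_` and `labels_`.

```python
import numpy as np

def assign_labels(X, centers):
    # Squared Euclidean distance from every point to every center: shape (n_points, n_clusters)
    dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    # Index of the closest center for each point
    return dists.argmin(axis=1)

def kmeans_lloyd(X, n_clusters, n_iter=100, seed=0):
    """Hypothetical minimal Lloyd's k-means: random init, fixed iteration cap, no restarts."""
    rng = np.random.default_rng(seed)
    # Initialize the centers with n_clusters distinct data points chosen at random
    centers = X[rng.choice(len(X), size=n_clusters, replace=False)].astype(float)
    for _ in range(n_iter):
        # Assignment step: each point goes to its nearest center
        labels = assign_labels(X, centers)
        # Update step: each center moves to the mean of its points (kept as-is if the cluster is empty)
        new_centers = np.array([
            X[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_clusters)
        ])
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    # Final assignment so that labels are consistent with the returned centers
    labels = assign_labels(X, centers)
    return centers, labels

# Two well-separated groups of 2-D points
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
centers, labels = kmeans_lloyd(X, n_clusters=2)
print(centers.shape)          # (2, 2): one center per cluster
print(centers[labels].shape)  # (6, 2): every point replaced by its cluster center
```

For the image case, each "point" is simply one pixel's (R, G, B) triple, so the centers are colors.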
I'll comment each line of the code to explain what is happening.
import os  # Used to build the file path
import matplotlib.pyplot as plt  # Plotting
from matplotlib.image import imread  # Function that reads an image file into a NumPy array
from sklearn.cluster import KMeans  # K-means clustering
image = imread(os.path.join("images","unsupervised_learning","ladybug.png"))  # Read the image
image.shape  # Shape of the image: (height, width, channels); 3 channels = RGB. Only displays something in an interactive session
X = image.reshape(-1, 3)  # Flatten to one row per pixel, one column per color channel: shape (height*width, 3)
kmeans = KMeans(n_clusters=8, random_state=42).fit(X)  # Cluster the pixel colors into 8 clusters
segmented_img = kmeans.cluster_centers_[kmeans.labels_]  # Replace every pixel by the center (mean color) of its cluster
segmented_img = segmented_img.reshape(image.shape)  # Reshape the list of pixels back into the original image shape
segmented_imgs = []
n_colors = (10, 8, 6, 4, 2)
for n_clusters in n_colors:
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)  # Run K-means again with 10, 8, 6, 4 and 2 clusters (i.e. colors)
    segmented_img = kmeans.cluster_centers_[kmeans.labels_]  # Same as above
    segmented_imgs.append(segmented_img.reshape(image.shape))
plt.figure(figsize=(10,5))  # Plotting code: the original image plus one panel per number of colors
plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.subplot(231)
plt.imshow(image)
plt.title("Original image")
plt.axis('off')
for idx, n_clusters in enumerate(n_colors):
    plt.subplot(232 + idx)
    plt.imshow(segmented_imgs[idx])
    plt.title("{} colors".format(n_clusters))
    plt.axis('off')
plt.show()
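Note that `reshape(-1, 3)` does not reorder the channels: in NumPy's default row-major order it flattens the height and width axes into a single list of pixels, and `-1` tells NumPy to infer that length. The small check below uses a made-up 2 x 3 "image" with arbitrary values (purely illustrative) to show the flattening and the round trip back to the original shape.

```python
import numpy as np

# A tiny made-up "image": height 2, width 3, 3 color channels (RGB)
image = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)

# One row per pixel, one column per channel: shape (2*3, 3) = (6, 3)
X = image.reshape(-1, 3)
print(X.shape)  # (6, 3)

# Row number i * width + j holds the RGB values of pixel (i, j)
height, width = image.shape[:2]
assert np.array_equal(X[1 * width + 2], image[1, 2])

# Reshaping back with the original shape restores the image exactly
restored = X.reshape(image.shape)
print(np.array_equal(restored, image))  # True
```

This is why `segmented_img.reshape(image.shape)` works: the per-pixel rows are still in the same order as the original pixels.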
In this line of code,
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
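As the names suggest, cluster_centers_ is an array attribute holding the coordinates of the cluster centers (here shape (n_clusters, 3), i.e. one RGB color per cluster), and labels_ is an attribute holding the cluster label of each point (shape (n_pixels,), integers from 0 to n_clusters - 1). Indexing cluster_centers_ with the integer array labels_ picks, for each point, the row of the center it belongs to. So segmented_img contains, for every pixel, the coordinates of its cluster's center: each pixel's color is replaced by the center (mean) color of its cluster, which is why the result has only n_clusters distinct colors.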
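The indexing itself is plain NumPy integer-array indexing and can be tried without K-means. In this sketch, `cluster_centers` and `labels` are hypothetical stand-ins for `kmeans.cluster_centers_` and `kmeans.labels_` (3 clusters, 5 pixels, values chosen by hand).

```python
import numpy as np

# Hypothetical cluster centers: one RGB color per cluster, shape (3, 3)
cluster_centers = np.array([
    [0.9, 0.1, 0.1],   # cluster 0: reddish
    [0.1, 0.8, 0.1],   # cluster 1: greenish
    [0.0, 0.0, 0.0],   # cluster 2: black
])

# Hypothetical cluster label of each of 5 pixels, shape (5,)
labels = np.array([2, 0, 0, 1, 2])

# Integer-array indexing: row i of the result is cluster_centers[labels[i]]
segmented = cluster_centers[labels]
print(segmented.shape)  # (5, 3)
print(segmented[1])     # [0.9 0.1 0.1]

# Equivalent explicit loop, for comparison
segmented_loop = np.array([cluster_centers[label] for label in labels])
assert np.array_equal(segmented, segmented_loop)
```

The result has one row per pixel but only as many distinct colors as there are clusters.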