Printing the images in each cluster

Time: 2019-08-12 05:06:15

Tags: python machine-learning image-processing computer-vision

I am using sklearn KMeans to cluster images, and I am having trouble printing the images that belong to each cluster.

  1. I have an np array of shape (10000, 100, 100, 3).
  2. I then flattened the images so that each row represents one image. Shape of train: (10000, 30000).
  3. I applied KMeans.
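
The flattening in step 2 can be sketched as follows. This is only an illustration under assumptions: `images` is a hypothetical stand-in for the real array, and it is deliberately much smaller than the (10000, 100, 100, 3) data described above.

```python
import numpy as np

# Hypothetical stand-in for the real data: 20 random RGB images of 8 x 8 pixels.
# The array in the question has shape (10000, 100, 100, 3).
images = np.random.random((20, 8, 8, 3))

# Flatten every image into one row; -1 lets NumPy infer height * width * channels.
train = images.reshape(len(images), -1)
print(train.shape)  # (20, 192)

# The flattening is lossless: reshaping back recovers the original images.
restored = train.reshape(images.shape)
```

Because the reshape is reversible, row `i` of `train` and `images[i]` always refer to the same picture, which is what makes it possible to look up cluster members later.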

    from scipy import ndimage
    from sklearn.cluster import KMeans

    kmeans = KMeans(n_clusters=10, random_state=0)
    clusters = kmeans.fit_predict(train)
    centers = kmeans.cluster_centers_
    

After this, I want to print the images that belong to each cluster.

1 answer:

Answer 0 (score: 0)

With ten clusters you get ten cluster centers. You can now print them or visualize them, which is what I think you want to do.

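    import numpy as np
    import matplotlib.pyplot as plt

    # fake centers, already shaped as images
    # (real kmeans.cluster_centers_ are flat rows and must be reshaped first)
    centers = np.random.random((10, 100, 100, 3))

    # print the raw pixel values of each center
    for ci in centers:
        print(ci)

    # visualize each center as an image
    for ci in centers:
        plt.imshow(ci)
        plt.show()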
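
With real data, `kmeans.cluster_centers_` has one flat row per cluster, so each center has to be reshaped back to image dimensions before it can be displayed. A minimal sketch, assuming small random images as a stand-in for the real data (the names `images` and `center_images` are illustrative only):

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical stand-in data: 30 random RGB images of 8 x 8 pixels in [0, 1].
rng = np.random.default_rng(0)
images = rng.random((30, 8, 8, 3))
n, h, w, c = images.shape

# Fit KMeans on the flattened images, one row per image.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(images.reshape(n, -1))

# cluster_centers_ has shape (n_clusters, h * w * c); restore the image layout.
center_images = kmeans.cluster_centers_.reshape(-1, h, w, c)
print(center_images.shape)  # (3, 8, 8, 3)
```

Each entry of `center_images` can then be passed to `plt.imshow`. If the original pixels are integers in 0..255 rather than floats in [0, 1], the float-valued centers would presumably need to be divided by 255 (or cast back to `uint8`) to display correctly.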

Edit: I understand that you want to visualize not only the centers but also the other members of each cluster.

For a single random member of each cluster, you can do the following:

    from sklearn.cluster import KMeans
    import numpy as np
    import matplotlib.pyplot as plt
    import random

    # parameters
    n_clusters = 10

    # fake training data: 100 RGB images of 100 x 100 pixels
    original_train = np.random.random((100, 100, 100, 3))

    n, x, y, c = original_train.shape

    # flatten each image into one row of length x*y*c
    flat_train = original_train.reshape((n, x * y * c))

    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    clusters = kmeans.fit_predict(flat_train)
    centers = kmeans.cluster_centers_

    # visualize the centers, reshaped back to image dimensions
    for ci in centers:
        plt.imshow(ci.reshape(x, y, c))
        plt.show()

    # visualize one random member of each cluster
    for cluster in range(n_clusters):
        cluster_member_indices = np.where(clusters == cluster)[0]
        print("There are %s members in cluster %s" % (len(cluster_member_indices), cluster))

        # pick a random member and show the original (unflattened) image
        random_member = random.choice(cluster_member_indices)
        plt.imshow(original_train[random_member])
        plt.show()
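
To look at several members of a cluster at once instead of a single random one, the same index lookup can feed a row of subplots. This is a sketch under assumptions, not part of the original answer: `plot_cluster_members` is a hypothetical helper, and the data and labels below are fake.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_cluster_members(images, labels, cluster, max_images=8, seed=0):
    """Draw up to max_images randomly chosen members of one cluster in a row."""
    member_indices = np.flatnonzero(labels == cluster)
    if member_indices.size == 0:
        raise ValueError("cluster %s has no members" % cluster)

    # Sample without replacement so no image is shown twice.
    rng = np.random.default_rng(seed)
    count = min(max_images, member_indices.size)
    chosen = rng.choice(member_indices, size=count, replace=False)

    # squeeze=False keeps axes two-dimensional even when count == 1.
    fig, axes = plt.subplots(1, count, figsize=(2 * count, 2), squeeze=False)
    for ax, idx in zip(axes[0], chosen):
        ax.imshow(images[idx])
        ax.axis("off")
    fig.suptitle("Cluster %s: %s members" % (cluster, member_indices.size))
    return fig, chosen


# Fake data: 12 random RGB images and hand-made labels for 3 clusters.
images = np.random.default_rng(1).random((12, 8, 8, 3))
labels = np.array([0, 1, 2] * 4)

fig, chosen = plot_cluster_members(images, labels, cluster=1, max_images=3)
```

With the variables from the answer above, this could be called as `plot_cluster_members(original_train, clusters, cluster)` inside a loop over `range(n_clusters)`, followed by `plt.show()` to display each figure.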