How to plot KMeans text clustering results with matplotlib?

Time: 2017-04-21 11:08:34

Tags: python matplotlib machine-learning scikit-learn

I have the following code that clusters some sample text with scikit-learn.

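import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer

train = ["is this good?", "this is bad", "some other text here", "i am hero", "blue jeans", "red carpet", "red dog", "blue sweater", "red hat", "kitty blue"]

vect = TfidfVectorizer()
X = vect.fit_transform(train)  # sparse tf-idf matrix, one row per sentence
clf = KMeans(n_clusters=3)
clf.fit(X)
centroids = clf.cluster_centers_

# this plots only the first two tf-idf features of each centroid
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=80, linewidths=5)
plt.show()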

What I can't figure out is how to plot the clustering results. X is a csr_matrix. What I want is an (x, y) coordinate for each result.
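To see why there is no ready-made (x, y) per document, it may help to inspect the shapes involved. This is a minimal sketch using the question's ten sentences; `n_init=10` and `random_state=0` are our own additions so that runs are repeatable. The per-document clustering result is `labels_`, while both `X` and `cluster_centers_` have one column per vocabulary term:

```python
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer

train = ["is this good?", "this is bad", "some other text here", "i am hero",
         "blue jeans", "red carpet", "red dog", "blue sweater", "red hat", "kitty blue"]

vect = TfidfVectorizer()
X = vect.fit_transform(train)  # scipy sparse matrix, one row per sentence
clf = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

print(X.shape)                     # (10, 18): 10 sentences, 18 vocabulary terms ("i" is too short to be a token)
print(clf.cluster_centers_.shape)  # (3, 18): centroids live in the same 18-D term space
print(clf.labels_)                 # one cluster index per sentence
```

So plotting `centroids[:, 0]` against `centroids[:, 1]` shows only the weights of the first two vocabulary terms, not a 2-D layout of the clusters; some projection to two dimensions is needed.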

2 answers:

Answer 0 (score: 2)

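Your centroid matrix ends up as 3 x 18 (3 clusters by 18 vocabulary terms), so you need some kind of projection or dimensionality reduction to get two-dimensional centroids. You have several options; here is an example using t-SNE: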

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.manifold import TSNE

train = ["is this good?", "this is bad", "some other text here", "i am hero", "blue jeans", "red carpet", "red dog",
     "blue sweater", "red hat", "kitty blue"]

vect = TfidfVectorizer()  
X = vect.fit_transform(train)
clf = KMeans(n_clusters=3)
data = clf.fit(X)
centroids = clf.cluster_centers_

tsne_init = 'pca'  # could also be 'random'
tsne_perplexity = min(20.0, len(centroids) - 1.0)  # newer scikit-learn requires perplexity < n_samples (3 centroids here)
tsne_early_exaggeration = 4.0
tsne_learning_rate = 1000
random_state = 1
model = TSNE(n_components=2, random_state=random_state, init=tsne_init, perplexity=tsne_perplexity,
         early_exaggeration=tsne_early_exaggeration, learning_rate=tsne_learning_rate)

# t-SNE is fitted on the 3 centroid rows only
transformed_centroids = model.fit_transform(centroids)
print(transformed_centroids)
plt.scatter(transformed_centroids[:, 0], transformed_centroids[:, 1], marker='x')
plt.show()

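In your example, if you initialize t-SNE with PCA you get widely spaced centroids; if you use random initialization, the centroids end up crammed into a tiny area and the picture is uninteresting.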
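Besides t-SNE, a linear projection is another option. The sketch below swaps in `TruncatedSVD` (commonly used for latent semantic analysis on tf-idf matrices); neither answer uses it. Because it is a fitted linear map, its `transform` method can project the KMeans centroids into the same 2-D space as the documents, which scikit-learn's `TSNE` cannot do since it only offers `fit_transform`. The helper name `project_with_svd` and the `random_state` values are illustrative choices:

```python
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer


def project_with_svd(texts, n_clusters=3, random_state=0):
    """Cluster tf-idf vectors with KMeans, then map documents and centroids to 2-D with TruncatedSVD."""
    X = TfidfVectorizer().fit_transform(texts)  # sparse, shape (n_docs, n_terms)
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state).fit(X)
    # TruncatedSVD accepts sparse input directly, so X does not need to be densified.
    svd = TruncatedSVD(n_components=2, random_state=random_state).fit(X)
    doc_xy = svd.transform(X)  # shape (n_docs, 2)
    # The centroids live in the same term space as X, so the same linear map applies to them.
    centroid_xy = svd.transform(km.cluster_centers_)  # shape (n_clusters, 2)
    return doc_xy, centroid_xy, km.labels_


if __name__ == "__main__":
    train = ["is this good?", "this is bad", "some other text here", "i am hero", "blue jeans",
             "red carpet", "red dog", "blue sweater", "red hat", "kitty blue"]
    doc_xy, centroid_xy, labels = project_with_svd(train)
    # documents colored by cluster label, centroids as large red crosses
    plt.scatter(doc_xy[:, 0], doc_xy[:, 1], c=labels)
    plt.scatter(centroid_xy[:, 0], centroid_xy[:, 1], c="red", marker="x", s=80, linewidths=3)
    for (x, y), text in zip(doc_xy, train):
        plt.annotate(text, (x, y), fontsize=8)  # label each point with its sentence
    plt.show()
```

The trade-off is that two linear components may separate clusters less visibly than t-SNE, but with a fixed `random_state` the axes are reproducible and centroids (or new documents) can be projected after fitting.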

Answer 1 (score: 1)

Here is a longer, better answer with more data:

import matplotlib.pyplot as plt
from numpy import concatenate
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.manifold import TSNE

train = [
    'In 1917 a German Navy flight crashed at/near Off western Denmark with 18 aboard',
    # 'There were 18 passenger/crew fatalities',
    'In 1942 a Deutsche Lufthansa flight crashed at an unknown location with 4 aboard',
    # 'There were 4 passenger/crew fatalities',
    'In 1946 Trans Luxury Airlines flight 878 crashed at/near Moline, Illinois with 25 aboard',
    # 'There were 2 passenger/crew fatalities',
    'In 1947 a Slick Airways flight crashed at/near Hanksville, Utah with 3 aboard',
    'There were 3 passenger/crew fatalities',
    'In 1949 a Royal Canadian Air Force flight crashed at/near Near Bigstone Lake, Manitoba with 21 aboard',
    'There were 21 passenger/crew fatalities',
    'In 1952 a Airwork flight crashed at/near Off Trapani, Italy with 57 aboard',
    'There were 7 passenger/crew fatalities',
    'In 1963 a Aeroflot flight crashed at/near Near Leningrad, Russia with 52 aboard',
    'In 1966 a Alaska Coastal Airlines flight crashed at/near Near Juneau, Alaska with 9 aboard',
    'There were 9 passenger/crew fatalities',
    'In 1986 a Air Taxi flight crashed at/near Frenchglen, Oregon with 6 aboard',
    'There were 3 passenger/crew fatalities',
    'In 1989 a Air Taxi flight crashed at/near Gold Beach, Oregon with 3 aboard',
    'There were 18 passenger/crew fatalities',
    'In 1990 a Republic of China Air Force flight crashed at/near Yunlin, Taiwan with 18 aboard',
    'There were 10 passenger/crew fatalities',
    'In 1992 a Servicios Aereos Santa Ana flight crashed at/near Colorado, Bolivia with 10 aboard',
    'There were 44 passenger/crew fatalities',
    'In 1994 Royal Air Maroc flight 630 crashed at/near Near Agadir, Morocco with 44 aboard',
    'There were 10 passenger/crew fatalities',
    'In 1995 Atlantic Southeast Airlines flight 529 crashed at/near Near Carrollton, GA with 29 aboard',
    'There were 44 passenger/crew fatalities',
    'In 1998 a Lumbini Airways flight crashed at/near Near Ghorepani, Nepal with 18 aboard',
    'There were 18 passenger/crew fatalities',
    'In 2004 a Venezuelan Air Force flight crashed at/near Near Maracay, Venezuela with 25 aboard',
    'There were 25 passenger/crew fatalities',
]

vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(train)
n_clusters = 2
random_state = 1
clf = KMeans(n_clusters=n_clusters, random_state=random_state)
data = clf.fit(X)
centroids = clf.cluster_centers_
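# we want to transform the rows and the centroids together (scikit-learn's TSNE only offers fit_transform);
# toarray() gives a plain ndarray, since newer scikit-learn rejects np.matrix input
everything = concatenate((X.toarray(), centroids))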

tsne_init = 'pca'  # could also be 'random'
tsne_perplexity = 20.0
tsne_early_exaggeration = 4.0
tsne_learning_rate = 10
model = TSNE(n_components=2, random_state=random_state, init=tsne_init,
    perplexity=tsne_perplexity,
    early_exaggeration=tsne_early_exaggeration, learning_rate=tsne_learning_rate)

transformed_everything = model.fit_transform(everything)
print(transformed_everything)
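# all rows except the last n_clusters are the documents...
plt.scatter(transformed_everything[:-n_clusters, 0], transformed_everything[:-n_clusters, 1], marker='x')
# ...and the last n_clusters rows are the centroids
plt.scatter(transformed_everything[-n_clusters:, 0], transformed_everything[-n_clusters:, 1], marker='o')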

plt.show()

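There are two clear clusters in the data: one for the crash descriptions and one for the fatality summaries. It is easy to comment out lines and adjust the number of clusters up or down. As written, the code should show two blue clusters, one larger and one smaller, plus two orange centroids. There are more data items than markers: some data rows are transformed to the same point in space.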
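A likely cause of the coinciding rows is the tokenizer (our reading, not stated in the answer): `TfidfVectorizer`'s default `token_pattern` keeps only tokens of two or more word characters, so one-digit fatality counts such as 3, 7 and 9 are discarded. Summaries that differ only in such a number become identical tf-idf vectors. A small check using three sentences in the style of the answer's data:

```python
from sklearn.feature_extraction.text import TfidfVectorizer

sentences = [
    "There were 3 passenger/crew fatalities",
    "There were 9 passenger/crew fatalities",
    "There were 18 passenger/crew fatalities",
]

vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(sentences).toarray()

# The default token_pattern r"(?u)\b\w\w+\b" drops 1-character tokens,
# so "3" and "9" disappear while "18" survives.
print(sorted(vectorizer.vocabulary_))
# ['18', 'crew', 'fatalities', 'passenger', 'there', 'were']

# Rows 0 and 1 are identical vectors, so they carry nothing that could separate them.
print((X[0] == X[1]).all(), (X[0] == X[2]).all())  # True False
```

If the counts matter for your clustering, passing a custom `token_pattern` (for example one that also accepts single digits) would keep them distinct.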

[Figure: two clusters]
Finally, a smaller t-SNE learning rate seems to produce tighter clusters.
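That observation is empirical, and t-SNE output also depends on perplexity and initialization, so it may be worth measuring on your own data. The sketch below is one way to do that. The spread measure (mean distance of each embedded point to its cluster's mean, divided by the mean distance to the overall mean so that embeddings of different scale stay comparable), the helper names, and the default `perplexity=5.0` are our own assumptions, not part of the answer:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.manifold import TSNE


def relative_spread(embedding, labels):
    """Within-cluster spread divided by overall spread (smaller means tighter clusters)."""
    within = []
    for k in np.unique(labels):
        pts = embedding[labels == k]
        within.extend(np.linalg.norm(pts - pts.mean(axis=0), axis=1))
    overall = np.linalg.norm(embedding - embedding.mean(axis=0), axis=1).mean()
    return float(np.mean(within) / overall)


def spread_by_learning_rate(texts, learning_rates, n_clusters=2, perplexity=5.0, random_state=1):
    """Cluster once with KMeans, then embed with t-SNE at each learning rate and score the result."""
    X = TfidfVectorizer().fit_transform(texts).toarray()  # dense input for t-SNE's PCA init
    labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state).fit_predict(X)
    result = {}
    for lr in learning_rates:
        # perplexity must stay below the number of rows in X
        emb = TSNE(n_components=2, init="pca", perplexity=perplexity,
                   learning_rate=lr, random_state=random_state).fit_transform(X)
        result[lr] = relative_spread(emb, labels)
    return result


if __name__ == "__main__":
    train = ["is this good?", "this is bad", "some other text here", "i am hero", "blue jeans",
             "red carpet", "red dog", "blue sweater", "red hat", "kitty blue"]
    for lr, spread in spread_by_learning_rate(train, [10, 200, 1000]).items():
        print(lr, round(spread, 3))
```

With so few points a single run says little; repeating over several `random_state` values gives a fairer comparison.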