K-means algorithm with MNIST data

Time: 2019-05-10 21:17:51

Tags: tensorflow jupyter-notebook k-means mnist

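I ran k-means on MNIST (loaded through TensorFlow) and tried to display the cluster center for one digit, say 7. However, the image I get is a completely different digit (a 9). Can you tell me what I am doing wrong?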

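import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# get the MNIST dataset (tensorflow.examples.tutorials is only available in TensorFlow 1.x)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)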
type(mnist)

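def kmeans(X, k, max_iter=10, rand_seed=0):
    np.random.seed(rand_seed)
    # pick k distinct data points as the initial centers
    Mu = X[np.random.choice(X.shape[0], k, replace=False), :]
    for _ in range(max_iter):
        # squared Euclidean distance from every point to every center,
        # ||x||^2 - 2 x.mu + ||mu||^2, shape (n_samples, k)
        D = -2 * X @ Mu.T + (X**2).sum(axis=1)[:, None] + (Mu**2).sum(axis=1)
        y = np.argmin(D, axis=1)
        # move each center to the mean of its assigned points
        # (keep the previous center if a cluster ends up empty)
        Mu = np.array([X[y == j].mean(axis=0) if np.any(y == j) else Mu[j]
                       for j in range(k)])
    # mean squared distance of each point to the center of its cluster
    loss = np.linalg.norm(X - Mu[y, :])**2 / X.shape[0]
    return Mu, y, loss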

## BEGIN SOLUTION

# start with k = 50
k = 50
# mnist.train.images is already flattened: each 28x28 image is a 784-value row, giving a 55000 x 784 matrix X
X = mnist.train.images

# kmeans() picks its own k random data points from X as the initial cluster centers,
# so no separate initialization is needed here
# call the kmeans algorithm defined above. Be sure to have the proper arguments passed to kmeans.
[Mu, y, loss] = kmeans(X, k, max_iter=10, rand_seed=0)

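# Plot one cluster center as an image. Note: k-means numbers its clusters
# arbitrarily, so cluster 2 is not necessarily the digit 7
plt.imshow(Mu[2].reshape(28, 28), cmap="gray")
plt.show()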

## END SOLUTION
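A likely explanation: k-means numbers its clusters arbitrarily (the order depends on which random points were picked as initial centers), so `Mu[2]` is just whichever cluster ended up at index 2 and has no reason to be the digit 7. One way to find the center(s) for a given digit is to label each cluster with the most common true digit among its members. The sketch below assumes integer digit labels (with `one_hot=True` they can be recovered as `mnist.train.labels.argmax(axis=1)`); the helpers `cluster_majority_labels` and `clusters_for_digit` are hypothetical names, not part of any library.

```python
import numpy as np

def cluster_majority_labels(y, labels, k):
    """For each of the k clusters, return the most frequent true label among its members.

    y      : cluster index assigned to each sample, shape (n,)
    labels : true (non-negative integer) digit of each sample, shape (n,)
    Empty clusters get -1.
    """
    majority = np.full(k, -1, dtype=int)
    for j in range(k):
        members = labels[y == j]
        if members.size > 0:
            # np.bincount counts how often each digit occurs in cluster j;
            # argmax picks the most frequent one (smallest digit on ties)
            majority[j] = np.bincount(members).argmax()
    return majority

def clusters_for_digit(majority, digit):
    """Return the indices of the clusters whose majority label is `digit`."""
    return np.flatnonzero(majority == digit)

# toy example: 6 samples split into 2 clusters
y = np.array([0, 0, 0, 1, 1, 1])
labels = np.array([9, 9, 7, 7, 7, 1])
majority = cluster_majority_labels(y, labels, k=2)
print(majority)                         # [9 7]
print(clusters_for_digit(majority, 7))  # [1]
```

With the question's variables this would be `digits = mnist.train.labels.argmax(axis=1)` and `majority = cluster_majority_labels(y, digits, k)`, followed by plotting `Mu[j].reshape(28, 28)` for each `j` in `clusters_for_digit(majority, 7)`. With k = 50 clusters and only 10 digits, several clusters necessarily share a majority digit, so there may be more than one center worth looking at.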

0 answers
