Calculating centroids and accuracy

Date: 2017-06-01 08:53:33

Tags: python deep-learning k-means

I got two features, feat_left and feat_right, from a Siamese network, and plotted the points in x, y coordinates as shown below.

[image: scatter plot of the points]

Here is the Python script:

import json
import matplotlib.pyplot as plt
import numpy as np

# Load the network outputs (the file holds JSON despite its .txt extension)
with open('predictions-mnist.txt') as fh:
    data = json.load(fh)

outputs = data['outputs']
n = len(outputs)
label_list = np.zeros(n, dtype=int)  # digit label of each sample
feat_left = np.zeros((n, 2))         # 2-D feature of each sample

for idx, (key, val) in enumerate(outputs.items()):
    feat_left[idx] = val['feat_left']
    # the digit label is the 7th '/'-separated component of the key
    label_list[idx] = int(key.split("/")[6])


f = plt.figure(figsize=(16,9))

c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
     '#ff00ff', '#990000', '#999900', '#009900', '#009999']

for i in range(10):
    plt.plot(feat_left[label_list==i,0].flatten(), feat_left[label_list==i,1].flatten(), '.', c=c[i])
plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
plt.grid()
plt.show()

Now I want to compute the centroid of each cluster and then the purity.

1 answer:

Answer 0 (score: 2)

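The centroid is simply the mean of the points in each cluster:

centroids = np.zeros((10, 2), dtype='f4')
for i in range(10):
    # mean of the 2-D features of all samples with label i
    centroids[i, :] = np.mean(feat_left[label_list == i, :2], axis=0)

As for accuracy, you can compute the squared error (distance) of each cluster's points from its centroid:

sqerr = np.zeros((10,), dtype='f4')
for i in range(10):
    # sum of squared distances from the label-i points to centroid i
    # (divide by the number of label-i points to get the mean squared error)
    sqerr[i] = np.sum((feat_left[label_list == i, :2] - centroids[i, :])**2)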
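The two per-cluster loops can also be written without Python loops. This is a minimal sketch, not part of the original answer: the function name `centroids_and_sqerr` is hypothetical, and it assumes the labels are integers in `[0, n_clusters)`. Empty clusters get a NaN centroid and a squared error of 0.

```python
import numpy as np


def centroids_and_sqerr(feats, labels, n_clusters=10):
    """Per-cluster centroids and sums of squared distances to them.

    feats:  (n, d) array of feature vectors
    labels: (n,) integer array with values in [0, n_clusters)
    """
    feats = np.asarray(feats, dtype=float)
    labels = np.asarray(labels)
    d = feats.shape[1]

    # number of points in each cluster (0 for empty clusters)
    counts = np.bincount(labels, minlength=n_clusters)

    # accumulate the feature sums per cluster, then divide by the counts
    sums = np.zeros((n_clusters, d))
    np.add.at(sums, labels, feats)
    with np.errstate(invalid='ignore', divide='ignore'):
        centroids = sums / counts[:, None]  # NaN rows for empty clusters

    # squared distance of every point to the centroid of its own cluster
    sq_dist = np.sum((feats - centroids[labels]) ** 2, axis=1)
    sqerr = np.bincount(labels, weights=sq_dist, minlength=n_clusters)
    return centroids, sqerr, counts


# tiny usage example with two clusters of two points each
demo_feats = np.array([[0., 0.], [2., 0.], [10., 10.], [10., 12.]])
demo_labels = np.array([0, 0, 1, 1])
demo_centroids, demo_sqerr, demo_counts = centroids_and_sqerr(demo_feats, demo_labels, n_clusters=2)
# demo_centroids is [[1, 0], [10, 11]] and demo_sqerr is [2, 2]
```

With the arrays from the question this would be called as `centroids_and_sqerr(feat_left, label_list)`; `np.add.at` is used instead of `sums[labels] += feats` because it accumulates correctly when the same cluster index appears more than once.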

To compute purity, you need the most frequent label in each cluster; for details on selecting it, see this answer.
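Purity compares cluster assignments produced by a clustering step (for example k-means) with the ground-truth labels: for each cluster, count the points carrying its most frequent true label, sum these counts over all clusters, and divide by the total number of points. If the clusters are simply the groups defined by `label_list` itself, purity is trivially 1, so this sketch assumes a separate, hypothetical array `cluster_ids`, e.g. from scikit-learn's `KMeans(n_clusters=10).fit_predict(feat_left)`. The function name `purity` is also hypothetical, and the labels must be non-negative integers because `np.bincount` requires that.

```python
import numpy as np


def purity(cluster_ids, true_labels):
    """Return (purity, majority), where majority maps cluster id -> most frequent label.

    true_labels must be non-negative integers (np.bincount requirement).
    """
    cluster_ids = np.asarray(cluster_ids)
    true_labels = np.asarray(true_labels)
    majority = {}
    correct = 0
    for k in np.unique(cluster_ids):
        # histogram of the true labels of the points assigned to cluster k
        counts = np.bincount(true_labels[cluster_ids == k])
        majority[int(k)] = int(np.argmax(counts))  # most frequent label in cluster k
        correct += int(counts.max())               # number of points with that label
    return correct / len(true_labels), majority


# cluster 0 holds labels 7, 7, 1 and cluster 1 holds 3, 3, 3, so purity is 5/6
p, majority = purity([0, 0, 0, 1, 1, 1], [7, 7, 1, 3, 3, 3])
```

The `majority` mapping is the "most frequent label in a cluster" mentioned above; it can also serve as a cluster-to-digit assignment if you want a classification-style accuracy.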