Optimizing k-means clustering code in PySpark RDD

Asked: 2018-04-22 18:19:52

Tags: dictionary optimization pyspark rdd k-means

This is quite a long question. Below is my k-means clustering code; for 10 clusters it takes me more than an hour to compute the result.

Can I make this faster by replacing the for and while loops with RDD operations such as map, flatMap, and filter? If so, where should I make the changes? (A sketch of one such rewrite follows the output below.)

Here bj is an RDD with more than 30,000 elements, each a dictionary. Every element is identified by its 'name' key; all other keys hold 0/1 values, which are used to compute Euclidean distances (a small distance sketch follows the sample below).

RDD contents:

bj.collect():

[{'name': 'ab',
  'abc': 0,
  'def': 0,
  'ghi': 0,
  'jkl': 0,
  ....},
 {'name': 'ak',
  'abc': 1,
  'def': 1,
  'ghi': 0,
  'jkl': 0,
  ....},...]
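
For concreteness, the distance the code computes between two such records is the Euclidean distance over every key except 'name'. A minimal sketch, assuming the records are complete dicts as shown (the helper name euclidean is my own):

import math

def euclidean(d1, d2):
    # Square root of the sum of squared differences over the 0/1 feature keys
    return math.sqrt(sum((d1[k] - d2[k]) ** 2 for k in d1 if k != 'name'))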



from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql import Row
import pyspark
import sys
from sys import argv
sc = SparkContext('local')
spark = SparkSession(sc)



all_states =["ab", "ak", "ar", "az", "ca", "co", "ct", "de", "dc", "fl",
         "ga", "hi", "id", "il", "in", "ia", "ks", "ky", "la", "me", "md",
         "ma", "mi", "mn", "ms", "mo", "mt", "ne", "nv", "nh", "nj", "nm",
         "ny", "nc", "nd", "oh", "ok", "or", "pa", "pr", "ri", "sc", "sd",
         "tn", "tx", "ut", "vt", "va", "vi", "wa", "wv", "wi", "wy", "al",
         "bc", "mb", "nb", "lb", "nf", "nt", "ns", "nu", "on", "qc", "sk",
         "yt", "dengl", "fraspm" ]


import random
random.seed(123)

init_states=random.sample(all_states, 3)

import math

# Initial centroids: the randomly sampled states themselves
centroids = bj.filter(lambda x: x['name'] in init_states)
ro = 2    # loop control: set to 3 once the clustering stops changing
mon = 0   # iteration counter
mp = 6    # sentinel: skip the centroid update on the first pass
while ro != 3:

    # Update centroids after the first clustering pass
    if mp != 6:
        R = []
        for i in range(len(jo)):
            j = jo[i][0]
            dc = {}
            d1 = bj.filter(lambda x: x['name'] == j).first()

            # Sum the feature vectors of every member of cluster i
            for r in jo[i]:
                if j != r:
                    d2 = bj.filter(lambda x: x['name'] == r).first()
                    dc = {key: d1[key] + d2[key] for key in d1.keys() if key != 'name'}
                    d1 = dc
                if j == r and len(jo[i]) == 1:
                    dc = bj.filter(lambda x: x['name'] == r).first()
            # Divide by the cluster size to get the new centroid
            dc = {key: dc[key] / len(jo[i]) for key in dc.keys() if key != 'name'}
            R.append(dc)
            R[i]['name'] = init_states[i]


        centroids=sc.parallelize(R)

    # Assignment step: attach each state to its nearest centroid
    mp = 3  # allow centroid updates on subsequent passes
    s = {}  # state -> name of its nearest centroid
    for i in all_states:
        mn = 300000  # sentinel larger than any possible distance

        dc = {}
        d1 = bj.filter(lambda x: x['name'] == i).first()
        for j in init_states:
            d2 = centroids.filter(lambda x: x['name'] == j).first()
            # Squared differences over the feature keys, skipping the 'name' label
            dc = {key: (d1[key] - d2[key]) ** 2 for key in d1.keys() if key != 'name'}
            val = math.sqrt(sum(dc.values()))

            if val < mn:
                s[i] = j
                mn = val

    # Group the states by their assigned centroid, in sorted order
    k = []
    for i in init_states:
        l = [j for j in all_states if s[j] == i]
        k.append(sorted(l))
    jo = sorted(k)

    # Stop once the clustering is identical to the previous iteration
    mon = mon + 1
    if mon > 1 and mu == jo:
        ro = 3
    mu = jo


# Print the final clusters
lp = 3
for j in range(len(mu)):
    if lp == 3:
        print("* Class", j)
    else:
        print("\n* Class", j)

    for i in mu[j]:
        lp = 2
        print(i, '', end="")

print('')

Output:

* Class 0
ab bc mb mn qc sk 
* Class 1
ak dengl lb nt nu yt 
* Class 2
al fl ga ms nc sc 
* Class 3
ar ks ky la mo ok tn tx 
* Class 4
az ia nd ne sd 
* Class 5
ca co id mt nm nv or ut wa wy 
* Class 6
ct dc de il in ma md me mi nh nj ny oh on pa ri va vt wi wv 
* Class 7
fraspm nb nf ns 
* Class 8
hi 
* Class 9
pr vi 
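
For reference, here is a sketch of the kind of RDD-based rewrite the question asks about. Each bj.filter(...).first() call above is a full scan of the 30,000-element RDD, and the loops issue one such scan per (state, centroid) pair. Broadcasting the current centroids lets the whole assignment step run as a single map, and the centroid update as a single reduceByKey. This is only a sketch under assumptions, not a drop-in replacement: the helper names (nearest_centroid, merge) are illustrative, and centroids is assumed here to be a plain Python list of centroid dicts rather than an RDD.

cent_b = sc.broadcast(centroids)  # centroids: a small list of centroid dicts

def nearest_centroid(point):
    # Squared Euclidean distance to every broadcast centroid; return the nearest one's name
    best = min(cent_b.value,
               key=lambda c: sum((point[k] - c[k]) ** 2
                                 for k in point if k != 'name'))
    return best['name']

# Assignment step: one pass over bj instead of one filter per (state, centroid) pair
features = bj.map(lambda p: (nearest_centroid(p),
                             ({k: v for k, v in p.items() if k != 'name'}, 1)))

# Update step: per-cluster feature sums and counts, then divide to get new centroids
def merge(a, b):
    return ({k: a[0][k] + b[0][k] for k in a[0]}, a[1] + b[1])

new_centroids = [dict({k: v / n for k, v in total.items()}, name=cname)
                 for cname, (total, n) in features.reduceByKey(merge).collect()]

Iterating these two steps until the assignment stops changing would replace both the while loop and the nested filters; the only data pulled back to the driver per iteration is one small (sum, count) pair per cluster.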

0 Answers:

No answers