这是一个非常漫长的问题。这是我的kmeans聚类代码。对于10个集群,我花了一个多小时来计算结果。
我可以使用RDD(map,flatmap,filter)等更快地使用for和while循环吗?如果是这样,我应该在哪里进行更改。
这里bj是一个包含30,000多个元素的RDD:这里使用的RDD是一个字典集合。我们可以使用键“名称”识别每个元素,并且所有其他键具有与其关联的值[1,0],以计算欧氏距离
Rdd内容:
bj.collect():
[{'name': 'ab',
'abc': 0,
'def': 0,
'ghi': 0,
'jkl': 0,
....},
{'name': 'ak',
'abc': 1,
'def': 1,
'ghi': 0,
'jkl': 0,
....},...]
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql import Row
import pyspark
import sys
from sys import argv
sc = SparkContext('local')
spark = SparkSession(sc)
all_states =["ab", "ak", "ar", "az", "ca", "co", "ct", "de", "dc", "fl",
"ga", "hi", "id", "il", "in", "ia", "ks", "ky", "la", "me", "md",
"ma", "mi", "mn", "ms", "mo", "mt", "ne", "nv", "nh", "nj", "nm",
"ny", "nc", "nd", "oh", "ok", "or", "pa", "pr", "ri", "sc", "sd",
"tn", "tx", "ut", "vt", "va", "vi", "wa", "wv", "wi", "wy", "al",
"bc", "mb", "nb", "lb", "nf", "nt", "ns", "nu", "on", "qc", "sk",
"yt", "dengl", "fraspm" ]
import random
random.seed(123)
init_states=random.sample(all_states, 3)
import math
centroids = bj.filter(lambda x: x['name'] in init_states)
ro=2
mon=0
mp=6
while(ro!=3):
#updating centroids after 1st cluster
if(mp!=6):
R=[]
for i in range(len(jo)):
j=jo[i][0]
dc={}
d1=bj.filter(lambda x: x['name']==j).first()
for r in jo[i]:
if(j!=r):
d2=bj.filter(lambda x: x['name']==r).first()
dc={key: d1[key] + d2[key] for key in d1.keys() if key not in {'name'} }
d1=dc
if(j==r)&(len(jo[i])==1):
d2=bj.filter(lambda x: x['name']==r).first()
dc=d2
dc={key: dc[key]/len(jo[i]) for key in dc.keys() if key not in {'name'} }
R.append(dc)
R[i]['name']=init_states[i]
centroids=sc.parallelize(R)
#Clustering
mp=3
cent=centroids
s={}
r={}
S=[]
for i in all_states:
mn=300000
dc={}
d1=bj.filter(lambda x: x['name']==i).first()
for j in init_states:
d2=centroids.filter(lambda x: x['name']==j).first()
dc={key: (d1[key] - d2[key])**2 for key in d1.keys() if key not in 'name'}
val=sum([v for v in dc.values()])
val=math.sqrt(val)
if(val<mn):
s[i]=j
mn=val
k=[]
for i in init_states:
l=[]
for j in all_states:
if(s[j]==i):
c=j
l.append(c)
k.append(sorted(l))
jo=sorted(k)
mon=mon+1
if(mon>1):
if(mu==jo):
ro=3
mu=jo
#print Cluster
lp=3;
for j in range(len(mu)):
if(lp==3):
print ("* Class",j)
else:
print("\n* Class",j)
for i in mu[j]:
lp=2
print(i,'',end="")
print('')
输出:
* Class 0
ab bc mb mn qc sk
* Class 1
ak dengl lb nt nu yt
* Class 2
al fl ga ms nc sc
* Class 3
ar ks ky la mo ok tn tx
* Class 4
az ia nd ne sd
* Class 5
ca co id mt nm nv or ut wa wy
* Class 6
ct dc de il in ma md me mi nh nj ny oh on pa ri va vt wi wv
* Class 7
fraspm nb nf ns
* Class 8
hi
* Class 9
pr vi