我正在R中实现kmeans算法,但是我遇到了可怕的性能问题。 我来自python java和C ++,所以我不太习惯用R编写代码,所以我想知道我是否可以就执行的基本操作获得建议。
首先是我获取两点之间距离的功能:
distance <- function(pt1, pt2){
pt1 <- pt1[0:NUMBER_OF_FEATURES]
pt2 <- pt2[0:NUMBER_OF_FEATURES]
pt2 <- t(pt2)
sum <- 0
counter <- 1
for (i in 1:nrow(pt2)){
sum <- sum + ((pt1[counter] - pt2[counter])^2)
counter <- counter + 1
}
value <- sqrt(sum)
return(value)
}
从我的理解看来,我看起来并不能做得更好,但是我知道我不应该真正在R中使用for循环。
我还有另一个功能,着重于更新每个群集的质心,我这样编码它:
update_centroids <- function(ptlst, centroids){
centroids <- matrix(, nrow = NUMBER_OF_CLUSTERS, ncol = NUMBER_OF_FEATURES)
for (i in 1:NUMBER_OF_CLUSTERS){
temp <- ptlst[which(ptlst$cluster == i),]
temp <- temp[0:NUMBER_OF_FEATURES]
print(ncol(temp))
centroid <- c()
for (j in 1:ncol(temp)){
centroid <- c(centroid, mean(as.numeric(unlist(temp[j]))))
}
print(centroid)
centroids[i,] <- centroid
}
print(centroids)
}
再次,据我了解,我实际上不应该像这样编写此部分,而应该使用通用的编写方法,这样做会更快。
总体而言,我的完整算法在虹膜数据集上运行的时间为2.24秒,而我在python中的实现在0.03秒内运行
所以我在这里显然做错了,有些事情需要花费大量时间,但我无法亲自上手
预先感谢您的回答, Shraneid
答案 0 :(得分:3)
distance <- function(pt1, pt2){
pt1 <- pt1[1:NUMBER_OF_FEATURES]
pt2 <- pt2[1:NUMBER_OF_FEATURES]
x <- sum((pt1 - pt2)^2)
value <- sqrt(x)
return(value)
}
对于第二个功能,您正在循环内增长对象,这在R中很慢。
我想您的数据如下:
NUMBER_OF_CLUSTERS <- 2
NUMBER_OF_FEATURES <- 4
n <- 100
set.seed(13)
ptlst <- data.frame(cluster = sample.int(NUMBER_OF_CLUSTERS, n, replace = T),
replicate(NUMBER_OF_FEATURES, rnorm(n)))
head(ptlst)
# cluster X1 X2 X3 X4
# 1 2 0.2731292 -2.84476384 0.6137843 2.10781521
# 2 1 0.7555251 1.71457759 0.4126145 1.57738122
# 3 1 -0.3490184 -1.22881682 -0.4588937 0.06149504
# 4 1 -0.5461908 -0.31407296 -0.6731785 -0.23792899
# 5 2 0.2343620 -0.06991232 0.1930543 -0.17730688
# 6 1 -0.2978282 -0.83760143 1.3829291 -1.17393025
所以我们可以尝试:
update_centroids <- function(ptlst){
t(sapply(1:NUMBER_OF_CLUSTERS, function(i) {
temp <- ptlst[which(ptlst$cluster == i),]
colMeans(temp)
}))
}
update_centroids(ptlst)
# cluster X1 X2 X3 X4
# [1,] 1 0.07365732 -0.0725119 -0.08745870 0.03406371
# [2,] 2 -0.24100628 -0.1044056 0.09288702 0.40949754
或使用data.table
require(data.table)
x <- as.data.table(ptlst)
x[, lapply(.SD, mean), keyby = cluster]
# cluster X1 X2 X3 X4
# 1: 1 0.07365732 -0.0725119 -0.08745870 0.03406371
# 2: 2 -0.24100628 -0.1044056 0.09288702 0.40949754
我建议您先阅读有关R的一些指南:
https://r4ds.had.co.nz/introduction.html https://cran.r-project.org/web/packages/data.table/vignettes/datatable-intro.html
等
在线上有很多有用的材料。