基本R的性能问题

时间:2018-11-28 11:40:43

标签: r performance

我正在R中实现kmeans算法,但是我遇到了可怕的性能问题。 我来自python java和C ++,所以我不太习惯用R编写代码,所以我想知道我是否可以就执行的基本操作获得建议。

首先是我获取两点之间距离的功能:

distance <- function(pt1, pt2){
  pt1 <- pt1[0:NUMBER_OF_FEATURES]
  pt2 <- pt2[0:NUMBER_OF_FEATURES]

  pt2 <- t(pt2)
  sum <- 0
  counter <- 1
  for (i in 1:nrow(pt2)){
    sum <- sum + ((pt1[counter] - pt2[counter])^2)
    counter <- counter + 1
  }
  value <- sqrt(sum)
  return(value)
} 

从我的理解看来,我看起来并不能做得更好,但是我知道我不应该真正在R中使用for循环。

我还有另一个功能,着重于更新每个群集的质心,我这样编码它:

update_centroids <- function(ptlst, centroids){
  centroids <- matrix(, nrow = NUMBER_OF_CLUSTERS, ncol = NUMBER_OF_FEATURES)

  for (i in 1:NUMBER_OF_CLUSTERS){
    temp <- ptlst[which(ptlst$cluster == i),]
    temp <- temp[0:NUMBER_OF_FEATURES]
    print(ncol(temp))
    centroid <- c()
    for (j in 1:ncol(temp)){
      centroid <- c(centroid, mean(as.numeric(unlist(temp[j]))))
    }
    print(centroid)
    centroids[i,] <- centroid
  }
  print(centroids)
}

再次,据我了解,我实际上不应该像这样编写此部分,而应该使用通用的编写方法,这样做会更快。

总体而言,我的完整算法在虹膜数据集上运行的时间为2.24秒,而我在python中的实现在0.03秒内运行

所以我在这里显然做错了,有些事情需要花费大量时间,但我无法亲自上手

预先感谢您的回答, Shraneid

编辑: dput generated file

1 个答案:

答案 0 :(得分:3)

distance <- function(pt1, pt2){
  pt1 <- pt1[1:NUMBER_OF_FEATURES]
  pt2 <- pt2[1:NUMBER_OF_FEATURES]
  x <- sum((pt1 - pt2)^2)
  value <- sqrt(x)
  return(value)
} 

对于第二个功能,您正在循环内增长对象,这在R中很慢。

我想您的数据如下:

NUMBER_OF_CLUSTERS <- 2
NUMBER_OF_FEATURES <- 4 
n <- 100
set.seed(13)
ptlst <- data.frame(cluster = sample.int(NUMBER_OF_CLUSTERS, n, replace = T),
                    replicate(NUMBER_OF_FEATURES, rnorm(n)))
head(ptlst)
#   cluster         X1          X2         X3          X4
# 1       2  0.2731292 -2.84476384  0.6137843  2.10781521
# 2       1  0.7555251  1.71457759  0.4126145  1.57738122
# 3       1 -0.3490184 -1.22881682 -0.4588937  0.06149504
# 4       1 -0.5461908 -0.31407296 -0.6731785 -0.23792899
# 5       2  0.2343620 -0.06991232  0.1930543 -0.17730688
# 6       1 -0.2978282 -0.83760143  1.3829291 -1.17393025

所以我们可以尝试:

update_centroids <- function(ptlst){
  t(sapply(1:NUMBER_OF_CLUSTERS, function(i) {
    temp <- ptlst[which(ptlst$cluster == i),]
    colMeans(temp)
  }))
}
update_centroids(ptlst)
#      cluster          X1         X2          X3         X4
# [1,]       1  0.07365732 -0.0725119 -0.08745870 0.03406371
# [2,]       2 -0.24100628 -0.1044056  0.09288702 0.40949754

或使用data.table

require(data.table)
x <- as.data.table(ptlst)
x[, lapply(.SD, mean), keyby = cluster]
#    cluster          X1         X2          X3         X4
# 1:       1  0.07365732 -0.0725119 -0.08745870 0.03406371
# 2:       2 -0.24100628 -0.1044056  0.09288702 0.40949754

我建议您先阅读有关R的一些指南:

https://r4ds.had.co.nz/introduction.html https://cran.r-project.org/web/packages/data.table/vignettes/datatable-intro.html

在线上有很多有用的材料。