R - 获取向量中最大n个元素的索引的最快方法

时间:2013-08-26 18:39:38

标签: r sorting vector

假设我有一个包含100万个元素的巨大向量x,我想找到最多30个元素的索引。我不特别在意结果是否在这30个元素中排序,只要它们是整个向量中的最大值30。使用order[x][1:30]似乎非常昂贵,因为它必须对整个向量进行排序。我考虑过使用partial中的sort选项,但sort会返回值,并且在指定index.return时不支持partial选项。有没有一种有效的方法来查找索引而不对整个向量进行排序?

6 个答案:

答案 0 :(得分:12)

我想使用sort partial参数和which添加混合方法:

whichpart <- function(x, n=30) {
  nx <- length(x)
  p <- nx-n
  xp <- sort(x, partial=p)[p]
  which(x > xp)
}

一些基准测试:

library("microbenchmark")
library("data.table")
library("compiler")

set.seed(123)
x <- rnorm(1e6)
y <- sample.int(1e6)


whichpart <- function(x, n=30) {
  nx <- length(x)
  p <- nx-n
  xp <- sort(x, partial=p)[p]
  which(x > xp)
}

cpwhichpart <- cmpfun(whichpart)

# using quicksort
quicksort <- function(x, n=30) {
  sort(x, method="quick", decreasing=TRUE, index.return=TRUE)$ix[1:n]
}

cpquicksort <- cmpfun(quicksort)

# @Mariam
whichsort <- function(x, n=30) {
  which(x >= sort(x, decreasing=TRUE)[30], arr.ind=TRUE)
}

cpwhichsort <- cmpfun(whichsort)

# @Ferdinand.kraft
top <- function(x, n=30) {
    result <- numeric()
    for(i in 1:n){
        j <- which.max(x)
        result[i] <- j
        x[j] <- -Inf
    }
    result
}

cptop <- cmpfun(top)

# @Tony Breyal
dtable <- function(x, n=30) {
  dt <- data.table(x=x, x.index=seq.int(x))
  setkey(dt, "x")
  dt$x.index[1:n]
}

cpdtable <- cmpfun(dtable)

# @Roland
roland <- cmpfun(function(x, n=30) {
  y <- rep(-Inf, n)
  for (i in seq_along(x)) {
    if (x[i] > y[1]) {
      y[1] <- x[i]
      y <- y[order(y)]
    }
  }
  y
})

## rnorm
microbenchmark(whichpart(x), cpwhichpart(x),
               quicksort(x), cpquicksort(x),
               whichsort(x), cpwhichsort(x),
               top(x), cptop(x),
               dtable(x), cpdtable(x),
               roland(x), times=10)

# Unit: milliseconds
#            expr        min         lq     median         uq        max neval
#    whichpart(x)   45.63544   46.05638   47.09077   49.68452   51.42065    10
#  cpwhichpart(x)   45.65996   45.77212   47.02808   48.07482   82.20458    10
#    quicksort(x)  100.90936  103.00783  105.17506  109.31784  139.83518    10
#  cpquicksort(x)  100.53958  102.78017  107.64470  138.96630  142.52882    10
#    whichsort(x)  148.86010  151.04350  155.80871  159.47063  184.56697    10
#  cpwhichsort(x)  149.05578  150.21183  151.36918  166.58342  173.87567    10
#          top(x)  146.10757  182.42089  184.53050  191.37293  193.62272    10
#        cptop(x)  155.14354  179.14847  184.52323  196.80644  220.21222    10
#       dtable(x) 1041.32457 1042.54904 1049.26096 1065.40606 1080.89969    10
#     cpdtable(x) 1042.08247 1043.54915 1051.76366 1084.14360 1310.26485    10
#       roland(x)  251.42885  261.47608  273.20838  295.09733  323.96257    10

## integer
microbenchmark(whichpart(y), cpwhichpart(y),
               quicksort(y), cpquicksort(y),
               whichsort(y), cpwhichsort(y),
               top(y), cptop(y),
               dtable(y), cpdtable(y),
               roland(y), times=10)

# Unit: milliseconds
#            expr       min        lq    median        uq       max neval
#    whichpart(y)  11.60703  11.76857  12.03704  12.52871  47.88526    10
#  cpwhichpart(y)  11.62885  11.75006  12.53724  13.88563  46.93677    10
#    quicksort(y)  88.14924  89.47630  92.42414 103.53439 137.44335    10
#  cpquicksort(y)  88.11544  89.15334  92.63420  94.42244 133.78006    10
#    whichsort(y) 122.34675 123.13634 124.91990 127.79134 131.43400    10
#  cpwhichsort(y) 121.85618 122.91653 125.45211 127.14112 158.61535    10
#          top(y) 163.06669 181.19004 211.11557 224.19237 239.63139    10
#        cptop(y) 163.37903 173.55113 209.46770 218.59685 226.81545    10
#       dtable(y) 499.50807 505.45513 514.55338 537.84129 604.86454    10
#     cpdtable(y) 491.70016 498.62664 525.05342 527.14666 580.19429    10
#       roland(y) 235.44664 237.52200 242.87925 268.34080 287.71196    10


identical(sort(quicksort(x)), whichpart(x))
# [1] TRUE

编辑:测试@ flodel的建议

# @flodel
whichpartrev <- function(x, n=30) {
  which(x >= -sort(-x, partial=n)[n])
}

microbenchmark(whichpart(x), whichpartrev(x), times=100)

# Unit: milliseconds
#             expr      min       lq   median       uq      max neval
#     whichpart(x) 45.44940 46.15011 46.51321 48.67986 80.63286   100
#  whichpartrev(x) 28.84482 31.30661 32.87695 62.37843 67.84757   100

microbenchmark(whichpart(y), whichpartrev(y), times=100)

# Unit: milliseconds
#             expr      min       lq   median       uq      max neval
#     whichpart(y) 11.56135 12.26539 13.05729 13.75199 43.78484   100
#  whichpartrev(y) 16.00612 16.73690 17.71687 19.04153 49.02842   100

答案 1 :(得分:1)

vec <- runif(1000000)
index <- which(vec >= sort(vec, decreasing=T)[30], arr.ind=TRUE)
vec[index]

答案 2 :(得分:1)

我不确定你是如何避免排序的。我认为?which.max可能会有所帮助。

无论如何,我会做类似以下的事情:

require(data.table)
set.seed(42)
n <- 1000000
x <- rnorm(n)
dt <- data.table(x = x, x.index = seq.int(1, n))
setkey(dt, "x")
tail(dt, 30)

#            x x.index
# 1: 0.9999712  270177
# 2: 0.9999715  521060
# 3: 0.9999723  863876
# 4: 0.9999757  622734
# 5: 0.9999761   48337
# 6: 0.9999764  699984
# 7: 0.9999766  264473
# 8: 0.9999770  212981
# 9: 0.9999782  911943
# 10: 0.9999874  330250
# 11: 0.9999876  695213
# 12: 0.9999879  219101
# 13: 0.9999880  144000
# 14: 0.9999880  459676
# 15: 0.9999887  910525
# 16: 0.9999894  902172
# 17: 0.9999900  474633
# 18: 0.9999905  360481
# 19: 0.9999920  985058
# 20: 0.9999925   17169
# 21: 0.9999926  424703
# 22: 0.9999927  448196
# 23: 0.9999929  254084
# 24: 0.9999932  468090
# 25: 0.9999940  480390
# 26: 0.9999961  765489
# 27: 0.9999966  556407
# 28: 0.9999968  860100
# 29: 0.9999982  879843
# 30: 0.9999989  507889

答案 3 :(得分:0)

使用这个简单的函数,我设法使用长度为10 ^ 7或更长的矢量来保存几秒钟:

top <- function(x, n=30){
    result <- numeric()
    for(i in 1:n){
        j <- which.max(x)
        result[i] <- x[j]
        x[j] <- -Inf
    }
    result
}

我的结果:

> x <- runif(1e7)
> system.time(y <- sort(x,decreasing=TRUE)[1:30])
   user  system elapsed 
   3.30    0.04    3.39 
> system.time(z <- top(x))
   user  system elapsed 
   2.49    0.58    3.12 
> x <- runif(1e8)
> system.time(y <- sort(x,decreasing=TRUE)[1:30])
   user  system elapsed 
  41.74    1.15   43.62 
> system.time(z <- top(x))
   user  system elapsed 
  25.96    7.61   34.43 

答案 4 :(得分:0)

set.seed(42)
x <- rnorm(1e6)

fun1 <- function(x) x[order(x, decreasing = TRUE)][1:30]

library(compiler)
fun2 <- cmpfun(function(x) {
  y <- rep(-Inf,30)
  for (i in seq_along(x)) {
    if (x[i] > y[1]) {
      y[1] <- x[i]
      y <- y[order(y)]
    }
  }
  y
})

library(microbenchmark)
microbenchmark(
y1 <- fun1(x),
y2 <- fun2(x),
times=5)
#Unit: milliseconds
#         expr      min       lq   median       uq      max neval
#y1 <- fun1(x) 400.1574 411.8172 418.7872 425.8027 426.2981     5
#y2 <- fun2(x) 255.7817 258.2374 258.8088 259.4630 290.6068     5


identical(sort(y1), sort(y2))
#[1] TRUE

答案 5 :(得分:-1)

我会按降序排序,然后先取n:

 x = rnorm(1000000, 10, 5)
 x = sort(x, decreasing = TRUE)
 n = 30
 print(head(x, n))