Optimizing the speed of a nearest-neighbour search function in R

Date: 2016-12-30 00:28:48

Tags: r list matrix

I am trying to make this function faster (ideally with RcppArmadillo or some other alternative). myfun takes a matrix, mat, which can get very large but always has two columns. For each row, myfun finds the other rows of the matrix that are "closest" to it, meaning the rows whose values differ by +1 or -1 in one column and not at all in the other (an absolute distance of 1).

As shown in the example below, the first row of mat is (3, 3). myfun therefore outputs a list whose first element contains rows 2 and 3, which are closest to row 1, but not row 5, which is 2 away.

library(microbenchmark)
dim(mat)
[1] 1000    2
head(mat)
      x y
[1,]  3 3
[2,]  3 4
[3,]  3 2
[4,]  7 3
[5,]  4 4
[6,] 10 1

output
[[1]]
[1] 2 3

[[2]]
[1] 1

[[3]]
[1] 1

[[4]]
integer(0)

[[5]]
integer(0)

[[6]]
integer(0)
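
A quick check of that claim, using only the six rows shown in head(mat) above (a hypothetical reconstruction for illustration; the full mat is not given):

ex <- cbind(x = c(3, 3, 3, 7, 4, 10), y = c(3, 4, 2, 3, 4, 1))  ## the rows from head(mat)
sum(abs(ex[2, ] - ex[1, ]))   ## 1 -> row 2 is a neighbour of row 1
sum(abs(ex[3, ] - ex[1, ]))   ## 1 -> so is row 3
sum(abs(ex[5, ] - ex[1, ]))   ## 2 -> row 5 is not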


microbenchmark( myfun(mat), times =  100) #mat of 1000 rows
#    Unit: milliseconds
#             expr      min       lq     mean   median       uq      max neval
#    myfun(mat) 89.30126 90.28618 95.50418 90.91281 91.50875 180.1505   100

microbenchmark( myfun(mat), times =  100) #mat of 10,000 rows
#    Unit: seconds
#             expr      min       lq     mean   median       uq      max neval
#    myfun(layout.old) 5.912633 5.912633 5.912633 5.912633 5.912633 5.912633     1

This is what myfun looks like:

myfun = function(x){
    doo <- function(j) {
        ## replicate row j to the dimensions of x, take the absolute
        ## difference, and keep the rows whose differences sum to exactly 1
        j.mat <- matrix(rep(j, length = length(x)), ncol = ncol(x), byrow = TRUE)
        j.abs <- abs(j.mat - x)
        return(which(rowSums(j.abs) == 1))
    }
    return(apply(x, 1, doo))
}
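
Since the question mentions Rcpp, a minimal compiled sketch of the same brute-force comparison might look like the following. This is purely illustrative (it is not part of the original question or of the answer below) and keeps the O(n^2) double loop, just moving it into C++:

library(Rcpp)

cppFunction('
List myfunCpp(const NumericMatrix& m) {
    int n = m.nrow();
    List out(n);
    for (int i = 0; i < n; ++i) {
        std::vector<int> idx;
        for (int j = 0; j < n; ++j) {
            double dx = m(i, 0) - m(j, 0); if (dx < 0) dx = -dx;
            double dy = m(i, 1) - m(j, 1); if (dy < 0) dy = -dy;
            if (dx + dy == 1) idx.push_back(j + 1);  // 1-based row index, as in myfun
        }
        out[i] = idx;
    }
    return out;
}')

## myfunCpp(mat) should give the same list as myfun(mat)
## (provided apply() in myfun does not simplify the result to a matrix)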

1 Answer:

Answer (score: 1):

Below is a base R solution that is much faster than the myfun provided by the OP. It relies on the fact that, for integer coordinates, two rows are at an absolute distance of 1 exactly when their row sums differ by 1 and neither column differs by more than 1, so the per-row work reduces to a few vector comparisons instead of building and reducing a whole difference matrix.

myDistOne <- function(m) {
    v1 <- m[,1L]; v2 <- m[,2L]
    rs <- rowSums(m)
    lapply(seq_along(rs), function(x) {
        t1 <- which(abs(rs[x] - rs) == 1)
        t2 <- t1[which(abs(v1[x] - v1[t1]) <= 1)]
        t2[which(abs(v2[x] - v2[t2]) <= 1)]
    })
}
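
A quick sanity check of that equivalence (my own illustration, not part of the answer): enumerate every integer coordinate difference in a small window and confirm that the absolute-distance test and the three filters used above agree.

d <- expand.grid(dx = -2:2, dy = -2:2)
dist1    <- abs(d$dx) + abs(d$dy) == 1                                ## distance exactly 1
filtered <- abs(d$dx + d$dy) == 1 & abs(d$dx) <= 1 & abs(d$dy) <= 1   ## myDistOne's tests
all(dist1 == filtered)
## [1] TRUE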

Here are some benchmarks:

library(microbenchmark)
set.seed(9711)
m1 <- matrix(sample(50, 2000, replace = TRUE), ncol = 2)   ## 1,000 rows
microbenchmark(myfun(m1), myDistOne(m1))
Unit: milliseconds
         expr      min       lq     mean   median       uq      max neval cld
    myfun(m1) 78.61637 78.61637 80.47931 80.47931 82.34225 82.34225     2   b
myDistOne(m1) 27.34810 27.34810 28.18758 28.18758 29.02707 29.02707     2  a 

identical(myfun(m1), myDistOne(m1))
[1] TRUE

m2 <- matrix(sample(200, 20000, replace = TRUE), ncol = 2)  ## 10,000 rows
microbenchmark(myfun(m2), myDistOne(m2))
Unit: seconds
         expr      min       lq     mean   median       uq      max neval cld
    myfun(m2) 5.219318 5.533835 5.758671 5.714263 5.914672 7.290701   100   b
myDistOne(m2) 1.230721 1.366208 1.433403 1.419413 1.473783 1.879530   100  a 

identical(myfun(m2), myDistOne(m2))
[1] TRUE

Here is a really large example:

m3 <- matrix(sample(1000, 100000, replace = TRUE), ncol = 2) ## 50,000 rows
system.time(testJoe <- myDistOne(m3))
   user  system elapsed 
 26.963  10.988  37.973 

system.time(testUser <- myfun(m3))
   user  system elapsed 
148.444  33.297 182.639 

identical(testJoe, testUser)
[1] TRUE

I'm sure there is a faster solution. Perhaps sorting the rowSums and working from there would yield an improvement (it could also get very messy).


Update

正如我所预测的那样,从排序的rowSums中工作要快得多(而且更加丑陋!)

myDistOneFast <- function(m) {
    v1 <- m[,1L]; v2 <- m[,2L]
    origrs <- rowSums(m)
    mySort <- order(origrs)        ## permutation that sorts the row sums
    rs <- origrs[mySort]           ## sorted row sums
    myDiff <- c(0L, diff(rs))      ## gaps between consecutive sorted row sums
    brks <- which(myDiff > 0L)     ## positions where the sorted row sum increases (start of each new group)
    lenB <- length(brks)
    n <- nrow(m)
    myL <- vector("list", length = n)

    ## v: positions (in sorted order) to fill; r: candidate positions from the
    ## neighbouring groups whose row sums differ by 1; s: mySort, which maps
    ## sorted positions back to the original row indices
    findRows <- function(v, s, r, u1, u2) {
        lapply(v, function(x) {
            sx <- s[x]
            tv1 <- s[r]
            tv2 <- tv1[which(abs(u1[sx] - u1[tv1]) <= 1)]
            tv2[which(abs(u2[sx] - u2[tv2]) <= 1)]
        })
    }

    t1 <- brks[1L]; t2 <- brks[2L]
    ## setting first index in myL
    myL[mySort[1L:(t1-1L)]] <- findRows(1L:(t1-1L), mySort, t1:(t2-1L), v1, v2)
    k <- t0 <- 1L

    while (k < (lenB-1L)) {
        t1 <- brks[k]; t2 <- brks[k+1L]; t3 <- brks[k+2L]
        vec <- t1:(t2-1L)
        if (myDiff[t1] == 1L) {
            if (myDiff[t2] == 1L) {
                myL[mySort[vec]] <- findRows(vec, mySort, c(t0:(t1-1L), t2:(t3-1L)), v1, v2)
            } else {
                myL[mySort[vec]] <- findRows(vec, mySort, t0:(t1-1L), v1, v2)
            }
        } else if (myDiff[t2] == 1L) {
            myL[mySort[vec]] <- findRows(vec, mySort, t2:(t3-1L), v1, v2)
        }
        if (myDiff[t2] > 1L) {
            if (myDiff[t3] > 1L) {
                k <- k+2L; t0 <- t2
            } else {
                k <- k+1L; t0 <- t1
            }
        } else {k <- k+1L; t0 <- t1}
    }

    ## setting second to last index in myL
    if (k == lenB-1L) {
        t1 <- brks[k]; t2 <- brks[k+1L]; t3 <- n+1L; vec <- t1:(t2-1L)
        if (myDiff[t1] == 1L) {
            if (myDiff[t2] == 1L) {
                myL[mySort[vec]] <- findRows(vec, mySort, c(t0:(t1-1L), t2:(t3-1L)), v1, v2)
            } else {
                myL[mySort[vec]] <- findRows(vec, mySort, t0:(t1-1L), v1, v2)
            }
        } else if (myDiff[t2] == 1L) {
            myL[mySort[vec]] <- findRows(vec, mySort, t2:(t3-1L), v1, v2)
        }
        k <- k+1L; t0 <- t1
    }

    t1 <- brks[k]; vec <- t1:n
    if (myDiff[t1] == 1L) {
        myL[mySort[vec]] <- findRows(vec, mySort, t0:(t1-1L), v1, v2)
    }

    myL
}

The results aren't even close. On very large matrices, myDistOneFast is more than 100x faster than the OP's original myfun, and it scales well too. Here are some benchmarks:

microbenchmark(OP = myfun(m1), Joe = myDistOne(m1), JoeFast = myDistOneFast(m1))
Unit: milliseconds
   expr      min       lq     mean   median       uq       max neval
     OP 57.60683 59.51508 62.91059 60.63064 61.87141 109.39386   100
    Joe 22.00127 23.11457 24.35363 23.87073 24.87484  58.98532   100
JoeFast 11.27834 11.99201 12.59896 12.43352 13.08253  15.35676   100

microbenchmark(OP = myfun(m2), Joe = myDistOne(m2), JoeFast = myDistOneFast(m2))
Unit: milliseconds
   expr       min        lq      mean    median        uq       max neval
     OP 4461.8201 4527.5780 4592.0409 4573.8673 4633.9278 4867.5244   100
    Joe 1287.0222 1316.5586 1339.3653 1331.2534 1352.3134 1524.2521   100
JoeFast  128.4243  134.0409  138.7518  136.3929  141.3046  172.2499   100

system.time(testJoeFast <- myDistOneFast(m3))
user  system elapsed 
0.68    0.00    0.69   ### myfun took over 100s!!!

To test equality, we have to sort each vector of indices. We also can't compare with identical, because myL is initialized as an empty list, so some of its elements are NULL (these correspond to the integer(0) results in myfun and myDistOne).

testJoeFast <- lapply(testJoeFast, sort)
all(sapply(1:50000, function(x) all(testJoe[[x]]==testJoeFast[[x]])))
[1] TRUE

unlist(testJoe[which(sapply(testJoeFast, is.null))])
integer(0)

Here is an example with 500,000 rows:

set.seed(42)
m4 <- matrix(sample(2000, 1000000, replace = TRUE), ncol = 2)
system.time(myDistOneFast(m4))
   user  system elapsed 
  10.84    0.06   10.94

Here is an overview of how the algorithm works (a simplified sketch of the same idea follows the list):

  1. Compute the rowSums.
  2. Order the rowSums (i.e. obtain the indices of the original vector that give the sorted vector).
  3. Call diff on the sorted row sums.
  4. Mark every non-zero instance (the breaks between groups of equal row sums).
  5. Determine which indices within the small neighbouring ranges satisfy the OP's request.
  6. Use the ordering vector computed in step 2 to recover the original indices.
  7. This is much faster than comparing one rowSum against all of the rowSums every time.
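
For illustration only, here is a minimal sketch of that sorted-rowSums idea, written as a simplified reconstruction rather than taken from the answer. It uses findInterval to locate the groups whose row sums differ by exactly 1 instead of handling the group boundaries by hand, and, as with myDistOneFast, the index vectors inside each element may come back in a different order:

sketchDistOne <- function(m) {
    v1 <- m[, 1L]; v2 <- m[, 2L]
    rs <- rowSums(m)
    ord <- order(rs)        ## ordering permutation (step 2 of the overview)
    srs <- rs[ord]          ## sorted row sums

    ## original row indices whose row sum equals the target value t
    rowsWithSum <- function(t) {
        lo <- findInterval(t - 0.5, srs) + 1L
        hi <- findInterval(t + 0.5, srs)
        if (hi < lo) integer(0) else ord[lo:hi]
    }

    lapply(seq_along(rs), function(i) {
        ## candidates: rows whose row sum is rs[i] - 1 or rs[i] + 1, then filter columns
        cand <- c(rowsWithSum(rs[i] - 1), rowsWithSum(rs[i] + 1))
        cand <- cand[abs(v1[i] - v1[cand]) <= 1]
        cand[abs(v2[i] - v2[cand]) <= 1]
    })
}

identical(lapply(sketchDistOne(m1), sort), lapply(myDistOne(m1), sort))
## should be TRUE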