大型数据集中的快速子集/查找/过滤

时间:2017-10-24 17:34:50

标签: r performance dplyr data.table subset

我需要在因子和数字列的大(多GB)表中重复查找“最近”行。使用dplyr,它看起来像这样:

# Synthetic example data: two 3-level character columns and two rounded
# standard-normal numeric columns, 300,000 rows in total.
df <- data.frame(
  factorA = rep(letters[1:3], times = 100000),
  factorB = sample(rep(letters[1:3], times = 100000), 300000, replace = FALSE),
  numC = round(rnorm(300000), 2),
  numD = round(rnorm(300000), 2)
)

closest <- function(ValueA, ValueB, ValueC, ValueD, dat = df, tol = 0.1) {
  # Find the single row of `dat` that exactly matches the two factor
  # values and whose numC/numD lie within a relative tolerance of
  # ValueC/ValueD, preferring the nearest numC and breaking ties by
  # the nearest numD.
  #
  # dat: data.frame to search. Defaults to the global `df`, preserving
  #      the original behaviour (the original read `df` directly).
  # tol: relative tolerance; 0.1 reproduces the original hard-coded
  #      0.9 / 1.1 bounds.
  # Returns the matching row as a named list; stops if nothing matches.
  df_sub <- dat %>%
    filter(factorA == ValueA,
           factorB == ValueB,
           numC >= (1 - tol) * ValueC,
           numC <= (1 + tol) * ValueC,
           numD >= (1 - tol) * ValueD,
           numD <= (1 + tol) * ValueD)

  if (nrow(df_sub) == 0) stop("Oh-oh, no candidates.")

  # Value of numC closest to ValueC among the candidates.
  minC <- df_sub$numC[which.min(abs(df_sub$numC - ValueC))]

  # Keep only rows at that numC, then pick the one nearest ValueD.
  df_sub %>%
    filter(numC == minC) %>%
    slice(which.min(abs(numD - ValueD))) %>%
    as.list()
}

以下是上述基准:

> microbenchmark(closest("a", "b", 0.5, 0.6))
Unit: milliseconds
                        expr      min       lq     mean   median       uq      max neval
 closest("a", "b", 0.5, 0.6) 25.20927 28.90623 35.16863 34.59485 35.25468 108.3489   100

优化此函数以提升速度的最佳方法是什么?内存(RAM)不是瓶颈——即使把整个大 df 放进内存也没问题,但考虑到该函数会被多次调用,我希望每次调用都尽可能快。

使用 data.table 代替 dplyr 会有帮助吗?

到目前为止,我尝试了两种优化:

# data.table copy of df, read by the data.table-based closest3() below.
dt <- as.data.table(df)

closest2 <- function(ValueA, ValueB, ValueC, ValueD, dat = df, tol = 0.1) {
  # Same contract as closest(), but using dplyr::between() for the
  # range filters.
  #
  # dat: data.frame to search. Defaults to the global `df`, preserving
  #      the original behaviour (the original read `df` directly).
  # tol: relative tolerance; 0.1 reproduces the original hard-coded
  #      0.9 / 1.1 bounds.
  # Returns the matching row as a named list; stops if nothing matches.
  df_sub <- dat %>%
    filter(factorA == ValueA,
           factorB == ValueB,
           dplyr::between(numC, (1 - tol) * ValueC, (1 + tol) * ValueC),
           dplyr::between(numD, (1 - tol) * ValueD, (1 + tol) * ValueD))

  if (nrow(df_sub) == 0) stop("Oh-oh, no candidates.")

  # Value of numC closest to ValueC among the candidates.
  minC <- df_sub$numC[which.min(abs(df_sub$numC - ValueC))]

  # Keep only rows at that numC, then pick the one nearest ValueD.
  df_sub %>%
    filter(numC == minC) %>%
    slice(which.min(abs(numD - ValueD))) %>%
    as.list()
}

closest3 <- function(ValueA, ValueB, ValueC, ValueD, dat = dt, tol = 0.1) {
  # data.table variant of closest(): exact match on the two factor
  # columns, range filter on numC/numD, then nearest numC with ties
  # broken by nearest numD.
  #
  # dat: data.table to search. Defaults to the global `dt`, preserving
  #      the original behaviour (the original read `dt` directly).
  # tol: relative tolerance; 0.1 reproduces the original hard-coded
  #      0.9 / 1.1 bounds.
  # Returns the matching row as a named list; stops if nothing matches.
  dt_sub <- dat[factorA == ValueA &
                  factorB == ValueB &
                  numC %between% c((1 - tol) * ValueC, (1 + tol) * ValueC) &
                  numD %between% c((1 - tol) * ValueD, (1 + tol) * ValueD)]

  if (nrow(dt_sub) == 0) stop("Oh-oh, no candidates.")

  # Compute the numC deviations once; the original evaluated
  # abs(numC - ValueC) twice inside the same subset expression.
  devC <- abs(dt_sub$numC - ValueC)
  as.list(dt_sub[devC == min(devC)][which.min(abs(numD - ValueD))])
}

基准:

> microbenchmark(closest("a", "b", 0.5, 0.6), closest2("a", "b", 0.5, 0.6), closest3("a", "b", 0.5, 0.6))
Unit: milliseconds
                         expr      min       lq     mean   median       uq       max neval cld
  closest("a", "b", 0.5, 0.6) 25.15780 25.62904 36.52022 34.68219 35.27116 155.31924   100   c
 closest2("a", "b", 0.5, 0.6) 22.14465 22.46490 27.81361 31.40918 32.04427  35.79021   100  b 
 closest3("a", "b", 0.5, 0.6) 13.52094 13.77555 20.04284 22.70408 23.41452 142.73626   100 a  

这可以更优化吗?

2 个答案:

答案 0 :(得分:3)

如果你可以把许多组查询值批量地一次性传入,而不是逐个顺序调用……

# Reproducible copy of the question's example data (seed fixed so the
# answer's results can be checked).
set.seed(1)
DF <- data.frame(
  factorA = rep(letters[1:3], times = 100000),
  factorB = sample(rep(letters[1:3], times = 100000), 300000, replace = FALSE),
  numC = round(rnorm(300000), 2),
  numD = round(rnorm(300000), 2)
)

library(data.table)
# DT is the data.table searched by the vectorised lookup f() below.
DT = data.table(DF)

f = function(vA, vB, nC, nD, dat = DT){
  # Vectorised nearest-row lookup over many query tuples at once.
  # Step 1: for each (vA, vB, nC) tuple, join exactly on the two factor
  #         columns and roll to the nearest numC; per tuple (by=.EACHI)
  #         capture its group number (.GRP), the matched row numbers (.I)
  #         and the numD values on those rows.
  # Step 2: join that result on (group, nD), rolling to the nearest numD
  #         and taking the first row on ties (mult="first"), yielding one
  #         row id per query tuple.
  rs <- dat[.(vA, vB, nC), on=.(factorA, factorB, numC), roll="nearest",
    .(g = .GRP, r = .I, numD), by=.EACHI][.(seq_along(vA), nD), on=.(g, numD), roll="nearest", mult="first",
    r]

  # BUG FIX: the original returned `df[rs]`, indexing the *global*
  # data.frame (`df[rs]` on a data.frame selects columns, not rows).
  # Index the table that was actually searched instead.
  dat[rs]
}

# example usage
# Each row of mDT is one query tuple (factorA, factorB, numC, numD).
mDT = data.table(vA = c("a", "b"), vB = c("c", "c"), nC = c(.3, .5), nD = c(.6, .8))

# Run all lookups in a single vectorised call of f().
mDT[, do.call(f, .SD)]

#    factorA factorB numC numD
# 1:       a       c  0.3 0.60
# 2:       b       c  0.5 0.76

与必须逐行运行的其他解决方案相比......

# check the results match
library(magrittr)
# closest3() reads the global `dt`, so point it at a copy of DT.
dt = copy(DT)
# Row-by-row reference run: one closest3() call per query tuple.
mDT[, closest3(vA, vB, nC, nD), by=.(mr = seq_len(nrow(mDT)))]

#    mr factorA factorB numC numD
# 1:  1       a       c  0.3 0.60
# 2:  2       b       c  0.5 0.76

# check speed for a larger number of comparisons

nr = 100
# Vectorised: one call handles all 200 tuples at once.
system.time( mDT[rep(1:2, each=nr), do.call(f, .SD)] )
#    user  system elapsed 
#    0.07    0.00    0.06 

# Row-wise: 200 separate closest3() calls for the same 200 tuples.
system.time( mDT[rep(1:2, each=nr), closest3(vA, vB, nC, nD), by=.(mr = seq_len(nr*nrow(mDT)))] )
#    user  system elapsed 
#   10.65    2.30   12.60 

工作原理

对于 .(vA, vB, nC) 中的每个元组,我们先精确匹配 vA 和 vB 对应的行,然后“滚动”到最接近的 numC 值——这并不完全符合 OP 的规则(在 nC * [0.9, 1.1] 范围内查找),但该规则可以很容易地在事后再应用。对于每个匹配,我们记录该元组的“组号”.GRP、匹配到的行号,以及这些行上的 numD 值。

然后我们加入组号和nD,完全匹配前者并滚动到最接近后者。如果有多个最接近的匹配项,我们会使用mult="first"进行第一次匹配。

然后我们可以获取每个元组匹配的行号,并在原始表中查找。

性能

因此,向量化的解决方案似乎具有很大的性能优势——这在 R 中一如既往。

如果你每次只能传入约 5 个元组(如 OP 的情况)而不是 200 个,那么相对于 which.min 的做法,这种方法以及 @F.Privé 在评论中提出的类似二分查找的方法,可能仍然会有优势。

正如 @HarlanNelson 的回答所述,在表上添加索引可能会进一步提高性能。请参阅他的回答和 ?setindex。

修正:roll="nearest" 对 numC 只滚动到单个值的问题

感谢OP确定此问题:

# Minimal example showing that roll="nearest" on a numeric key returns
# only one row even when several rows share the nearest numC value.
DT2 = data.table(id = "A", numC = rep(c(1.01,1.02), each=5), numD = seq(.01,.1,.01))
DT2[.("A", 1.011), on=.(id, numC), roll="nearest"]
#    id  numC numD
# 1:  A 1.011 0.05

在这里,我们只看到一行,但应该看到五行。一种修复方法(虽然我不清楚原因)是把数值键转换为整数:

DT3 = copy(DT2)
# Workaround: scale the numeric keys to integers so the rolling join
# returns every row sharing the nearest key value.
DT3[, numC := as.integer(numC*100)]
DT3[, numD := as.integer(numD*100)]
DT3[.("A", 101.1), on=.(id, numC), roll="nearest"]
#    id numC numD
# 1:  A  101    1
# 2:  A  101    2
# 3:  A  101    3
# 4:  A  101    4
# 5:  A  101    5

答案 1 :(得分:2)

这有点作弊,因为我在基准测试之前就建立了索引,但我假设你会在同一个 data.table 上多次运行查询。

library(data.table)
dt<-as.data.table(df)
# Pre-sort/key by the two factor columns so the join in closest2()
# below is a fast keyed (binary-search) lookup.
setkey(dt,factorA,factorB)

closest2 <- function(ValueA, ValueB, ValueC, ValueD) {
  # Use the keyed data.table join to narrow to the (factorA, factorB)
  # pair first, then apply the numeric tolerance filters with dplyr.
  keyed <- dt[.(ValueA, ValueB), on = c('factorA', 'factorB')]

  df_sub <- keyed %>%
    filter(numC >= 0.9 * ValueC, numC <= 1.1 * ValueC,
           numD >= 0.9 * ValueD, numD <= 1.1 * ValueD)

  if (nrow(df_sub) == 0) stop("Oh-oh, no candidates.")

  # Candidate numC value nearest to ValueC.
  minC <- df_sub[which.min(abs(df_sub$numC - ValueC)), "numC"]

  # Restrict to that numC, then break ties by the nearest numD.
  df_sub %>%
    filter(numC == minC) %>%
    slice(which.min(abs(numD - ValueD))) %>%
    as.list()
}

library(microbenchmark)
# Compare the question's dplyr-only closest() with the keyed-join closest2().
microbenchmark(closest("a", "b", 0.5, 0.6))
microbenchmark(closest2("a", "b", 0.5, 0.6))


Unit: milliseconds
                       expr      min       lq     mean   median       uq      max neval
closest("a", "b", 0.5, 0.6) 20.29775 22.55372 28.08176 23.20033 25.42154 127.7781   100
Unit: milliseconds
                       expr      min       lq     mean   median      uq      max neval
 closest2("a", "b", 0.5, 0.6) 8.595854 9.063261 9.929237 9.396594 10.0247 16.92655   100