有多快的方法可以在多个条件下评估data.table [i,j]的 i 吗? (我的实际数据表有2M行。我想用data.table操作,而不是循环或lapply)。
例如,假设我有:
require(data.table)
data = data.table(seq(0.25,10, by = 0.25), rep(c("a","b","c","d"),10))
filter = seq(0,10,by=1)
我现在想过滤,说:
data[V1 > filter[4], .N, by=V2]
如何为过滤器的所有元素评估此表达式?
答案 0 :(得分:1)
不确定这是否会像在这里一样好用,但这或多或少都是这个想法:
1)首先,在V2
上设置密钥,然后设置V1
:
setkey(data, V2, V1)
2)接下来,我们添加一些我们稍后会使用的列:
data[, `:=`(id = 1:.N, N = .N), by=V2][, id := id[1L], by=list(V2,V1)]
第一个是从1到.N的运行序列,第二个是总计数。
3)这里发生了所有的魔法:
data[, V1.b := V1] ### this is necessary in 1.8.10 as V1 is not available in `j`
### fixed in v1.8.11
ans <- data[CJ(x=unique(V2), y=filter),
list(start=id[1L], end=N[1L],
actual_num=V1.b[1L], close_match=y[1L]),
roll="nearest"]
让我们分开并理解。第一部分CJ
创建了一个组合,用join
一次获得所有结果(因此首先是setkey
)。对于i
中的每个值,我们使用roll="nearest"
来确保我们确实获得匹配(最接近的值),然后记下start
,end
{符合的的{1}}和V1
值。那么为什么我们需要所有这些价值呢?特别是为什么y
和V1
?
4)现在,根据此结果,y
给出匹配所在的第一个位置,start
总是给出end
的元素总数。但是,有一个问题。如果您要查找的号码计为V2
且V1 > 5
中的最接近值为V1
(&gt; 5),那么5.5
位置是正确的。但是,如果start
的最接近值为V1
,那么我们将4.5
增加1,因为我们的匹配是前一行。
要接受很多......但是一步一步地做这件事应该会有所帮助。所以,基本上我们现在这样做:
start
我之前解释的正确(此处ans[actual_num <= close_match, start := start+1L]
为V1
)。
5)现在,我们可以actual_num
来获得总数:
end-start+1
6)清理:
ans[, tot_cnt := end-start+1L]
全部放在一起:
ans[, `:=`(start=NULL, end=NULL, close_match=NULL, actual_num=NULL)]
setnames(ans, 'V1', 'filter')
setkey(ans, filter)
在@ eddi的数据上运行此操作大约需要2.4秒。
答案 1 :(得分:1)
尝试以下方法:
data = data.table(val = seq(0.25,10, by = 0.25), grp = rep(c("a","b","c","d"),10))
filter = seq(0,10,by=1)
fl = data.table(filter, key = 'filter')
# to get strict inequality I subtracted a "small" number
# adjust it appropriately for your data
data[, max.filter := fl[J(val - 1e-7), .I, roll = Inf]$.I][,
lapply(seq_along(filter), function(i) sum(max.filter >= i)), by = grp]
# grp V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11
#1: a 10 9 8 7 6 5 4 3 2 1 0
#2: b 10 9 8 7 6 5 4 3 2 1 0
#3: c 10 9 8 7 6 5 4 3 2 1 0
#4: d 10 9 8 7 6 5 4 3 2 1 0
对具有2M行和200个过滤器值的数据进行测试需要花费5秒多的时间(而且,要进行比较,天真的lapply
需要花费超过一分钟):
N = 2e6
data = data.table(val = runif(N, 1, N), grp = sample(letters, N, T))
filter = seq(0, N, by = N/200)
fl = data.table(filter, key = 'filter')
system.time(data[, max.filter := fl[J(val - 1e-7), .I, roll = Inf]$.I][, lapply(seq_along(filter), function(i) sum(max.filter >= i)), by = grp])
# user system elapsed
# 5.24 0.00 5.41
system.time(lapply(filter, function(x) data[val > x, .N, by = grp]))
# user system elapsed
# 71.07 0.03 73.75
答案 2 :(得分:0)
如果我理解正确,OP希望计算每个grp
以及filter
中给出的每个阈值超过阈值的值的数量。
这可以使用non-equi joins
版本v1.9.8(2016年11月25日CRAN)上的data.table
来解决。
这是两个变种。第一个简单地计算高于每个阈值的所有值,这意味着一些值被比较并计数多次。
第二种变体试图避免多重比较并计算每个区间内的值,并在第二步计算累积和。
基准测试显示加速比较小。
aggregated <- data[CJ(grp = unique(grp), lb = filter),
on = .(grp, val > lb), .N, by = .EACHI]
dcast(aggregated, grp ~ val)
grp 0 1 2 3 4 5 6 7 8 9 10 1: a 10 9 8 7 6 5 4 3 2 1 0 2: b 10 9 8 7 6 5 4 3 2 1 0 3: c 10 9 8 7 6 5 4 3 2 1 0 4: d 10 9 8 7 6 5 4 3 2 1 0
请注意,列标题表示阈值。
fl <- CJ(grp = unique(data$grp), lower = filter)[
, upper := shift(lower, type = "lead", fill = Inf)][]
tmp <- data[fl, on = .(grp, val > lower, val <= upper), .N, by = .EACHI ][
order(-val), .(val, N = cumsum(N)), by = grp][]
dcast(tmp, grp ~ val)
grp 0 1 2 3 4 5 6 7 8 9 10 1: a 10 9 8 7 6 5 4 3 2 1 0 2: b 10 9 8 7 6 5 4 3 2 1 0 3: c 10 9 8 7 6 5 4 3 2 1 0 4: d 10 9 8 7 6 5 4 3 2 1 0
不幸的是,我无法获得Arun和Eddi提供的解决方案来处理实际的data.table版本1.10.5。因此,此基准测试只是使用Eddi的基准数据比较两个非equi join 变体。
# create benchmark data
N = 2e6
set.seed(123L)
data = data.table(val = runif(N, 1, N), grp = sample(letters, N, T))
filter = seq(0, N, by = N/200)
# define check function
my_check <- function(values) {
all(sapply(values[-1], function(x) identical(values[[1]], x)))
}
#run benchmark
microbenchmark::microbenchmark(
nej1 = {
aggregated <- data[CJ(grp = unique(grp), lb = filter), on = .(grp, val > lb), .N, by = .EACHI]
dcast(aggregated, grp ~ val, value.var = "N")
},
nej2 =
{
fl <- CJ(grp = unique(data$grp), lower = filter)[, upper := shift(lower, type = "lead", fill = Inf)][]
tmp <- data[fl, on = .(grp, val > lower, val <= upper), .N, by = .EACHI ][
order(-val), .(val, N = cumsum(N)), by = grp][]
dcast(tmp, grp ~ val, value.var = "N")
},
check = my_check,
times = 20L
)
Unit: milliseconds expr min lq mean median uq max neval cld nej1 358.6231 368.8033 389.6501 378.1158 391.8690 556.3323 20 b nej2 347.6321 360.4038 365.2218 366.9392 370.1697 382.8185 20 a
变体2的速度优势不到5%。如果使用计算上更密集的聚合函数,这可能会有所不同。