我想查看一个列表,然后检查该项是否是列表中最常用的项目。与Python相比,我目前拥有的解决方案非常慢。有没有一种有效的方法来加速它?
dat<-data.table(sample(1:50,10000,replace=T))
k<-1
correct <- 0 # total correct predictions
for (i in 2:(nrow(dat)-1)) {
if (dat[i,V1] %in% dat[1:(i-1),.N,by=V1][order(-N),head(.SD,k)][,V1]) {
correct <- correct + 1
}
}
更一般地说,我最终想看一个项目是否是最多的项目之一 频繁的项目直到一个点,或者它有一个k最高值,直到一个点。
为了比较,这是Python中一个非常快速的实现:
dat=[random.randint(1,50) for i in range(10000)]
correct=0
k=1
list={}
for i in dat:
toplist=heapq.nlargest(k,list.iteritems(),key=operator.itemgetter(1))
toplist=[j[0] for j in toplist]
if i in toplist:
correct+=1
if list.has_key(i):
list[i]=list[i]+1
else:
list[i]=1
答案 0 :(得分:3)
这是我到目前为止所得到的(我的解决方案是f3):
set.seed(10)
dat<-data.table(sample(1:3,100,replace=T))
k<-1
f3 <- function(dat) {
correct <- 0 # total correct predictions
vf <- factor(dat$V1)
v <- as.integer(vf)
tabs <- integer(max(v))
for (i in 2:(nrow(dat)-1)) {
tabs[v[i-1]] <- tabs[v[i-1]] + 1
#print(tabs)
#print(v[1:i])
if (match(v[i],order(tabs,decreasing = T))<=k) {
correct <- correct + 1
}
#print(correct)
#print('')
}
correct
}
f1 <- function(dat) {
correct <- 0 # total correct predictions
for (i in 2:(nrow(dat)-1)) {
if (dat[i,V1] %in% dat[1:(i-1),.N,by=V1][order(-N),head(.SD,k)]) {
correct <- correct + 1
}
}
correct
}
library(rbenchmark)
print(f1(dat)==f3(dat))
library(rbenchmark)
benchmark(f1(dat),f3(dat),replications=10)
基准测试结果:
test replications elapsed relative user.self sys.self user.child sys.child
1 f1(dat) 10 2.939 163.278 2.931 0.008 0 0
2 f3(dat) 10 0.018 1.000 0.018 0.000 0 0
令人鼓舞,但f3
有两个问题:
它并不总是提供与OP算法相同的答案,因为关系的处理方式不同,
还有很大的改进空间,因为tabs
每次重新排序。
答案 1 :(得分:3)
条件自动为真,直至观察到k + 1值:
startrow <- dat[,list(.I,.GRP),by=V1][.GRP==k+1]$.I[1]
correct <- rep(0L,length(v))
correct[1:(startrow-1)] <- 1L
您可以预先计算V1
值到目前为止的出现次数:
ct <- dat[,ct:=1:.N,by=V1]$ct
在循环过程中,我们可以检查第k个最常见的值是否被当前值击倒。
startrow
:topk <- sort(tapply(ct[1:(startrow-1)],v[1:(startrow-1)],max))
thresh <- unname(topk[1])
startrow
循环到length(v)
,每当达到阈值时,更新correct
(此处为向量,而不是运行总和);如果达到阈值,则更新top-k俱乐部和该俱乐部中尚未存在该值。那就是它;其余的只是细节。这是我的职责:
ff <- function(dat){
vf <- factor(dat$V1)
v <- as.integer(vf)
ct <- dat[,ct:=1:.N,by=V1]$ct
n <- length(v)
ct <- setNames(ct,v)
startrow <- dat[,list(.I,.GRP),by=V1][.GRP==k+1]$.I[1]
topk <- sort(tapply(ct[1:(startrow-1)],v[1:(startrow-1)],max))
thresh <- unname(topk[1])
correct <- rep(0L,n)
correct[1:(startrow-1)] <- 1L
for (i in startrow:n) {
cti = ct[i]
if ( cti >= thresh ){
correct[i] <- 1L
if ( cti > thresh & !( names(cti) %in% names(topk) ) ){
topk <- sort(c(cti,topk))[-1]
thresh <- unname(topk[1])
}
}
}
sum(correct)
}
这是非常快的,但与@ MaratTalipov和OP的结果不同:
set.seed(1)
dat <- data.table(sample(1:50,10000,replace=T))
k <- 5
f1(dat) # 1012
f3(dat) # 1015
ff(dat) # 1719
这是我的基准(不包括f1()
封装的OP方法,因为我不耐烦):
> benchmark(f3(dat),ff(dat),replications=10)[,1:5]
test replications elapsed relative user.self
1 f3(dat) 10 2.68 2.602 2.67
2 ff(dat) 10 1.03 1.000 1.03
我的功能提供了比@ Marat和OP更多的匹配,因为它允许阈值处的关系计为&#34;正确&#34;,而他们只计算匹配最多k由R&#39; s order
函数使用的任何算法选择的值。
答案 2 :(得分:3)
[新解决方案]
dplyr
有一个快速且非常简单的k=1
解决方案。下面的fC1
平等对待关系,即没有打破平局。你会看到你可以对它施加任何打破平局的规则。而且,它真的很快。
library(dplyr)
fC1 <- function(dat){
dat1 <- tbl_df(dat) %>%
group_by(V1) %>%
mutate(count=row_number()-1) %>% ungroup() %>% slice(2:n()-1) %>%
filter(count!=0) %>%
mutate(z=cummax(count)) %>%
filter(count==z)
z <- dat1$z
length(z)
}
set.seed(1234)
dat<-data.table(sample(1:5000, 100000, replace=T))
system.time(a1 <- fC1(dat))[3] #returns 120
elapsed
0.04
system.time(a3m <- f3m(dat, 1))[3] #returns 29, same to the Python result which runs about 60s
elapsed
89.72
system.time(a3 <- f3(dat, 1))[3] #returns 31.
elapsed
95.07
您可以自由地对fC1
的结果施加一些打破平局规则,以达成不同的解决方案。例如,为了获得f3m
或f3
个解决方案,我们会限制某些行的选择,如下所示
fC1_ <- function(dat){
b <- tbl_df(dat) %>%
group_by(V1) %>%
mutate(count=row_number()-1) %>%
ungroup() %>%
mutate(L=cummax(count+1))# %>%
b1 <- b %>% slice(2:(n()-1)) %>%
group_by(L) %>%
slice(1) %>%
filter(count+1>=L& count>0)
b2 <- b %>% group_by(L) %>%
slice(1) %>%
ungroup() %>%
select(-L) %>%
mutate(L=count)
semi_join(b1, b2, by=c("V1", "L")) %>% nrow
}
set.seed(1234)
dat <- data.table(sample(1:50,10000,replace=T))
fC1_(dat)
#[1] 218
f3m(dat, 1)
#[1] 217
f3(dat, 1)
#[1] 218
以及前面的例子
set.seed(1234)
dat<-data.table(sample(1:5000, 100000, replace=T))
system.time(fC1_(dat))[3];fC1_(dat)
#elapsed
# 0.05
#[1] 29
不知何故,我无法扩展一般k>1
的解决方案,所以我使用了Rcpp。
#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::export]]
std::vector<int> countrank(std::vector<int> y, int k) {
std::vector<int> v(y.begin(), y.begin() + k);
std::make_heap(v.begin(), v.end());
std::vector<int> count(y.size());
for(int i=0; i < y.size(); i++){
if(y[i]==0){count[i]=0;}
else{
v.push_back(y[i]); std::push_heap(v.begin(), v.end());
std::pop_heap(v.begin(), v.end()); v.pop_back();
std::vector<int>::iterator it = std::find (v.begin(), v.end(), y[i]);
if (it != v.end()) {count[i]=1;};
}
}
return count;
}
对于k=1
,值得注意的是fC1
至少与以下Rcpp版本fCpp
一样快。
fCpp <- function(dat, k) {
dat1 <- tbl_df(dat) %>%
group_by(V1) %>%
mutate(count=row_number())
x <- dat1$V1
y <- dat1$count-1
z <- countrank(-y, k)
sum(z[2:(nrow(dat)-1)])
}
同样,你可以轻松地施加任何打破平局的规则。
[f3, f3m
函数]
f3
来自@Marat Talipov,f3m
是对它的一些修正(虽然看起来多余)。
f3m <- function(dat, k){
n <- nrow(dat)
dat1 <- tbl_df(dat) %>%
group_by(V1) %>%
mutate(count=row_number())
x <- dat1$V1
y <- dat1$count
rank <- rep(NA, n)
tablex <- numeric(max(x))
for(i in 2:(n-1)){
if(y[i]==1){rank[i]=NA} #this condition was originally missing
else{
tablex[x[i-1]] = y[i-1]
rank[i]=match(x[i], order(tablex, decreasing = T))
}
}
rank <- rank[2:(n-1)]
sum(rank<=k, na.rm=T)
}
请参阅早期解决方案的编辑历史记录。
答案 3 :(得分:2)
这个解决方案怎么样:
# unique values
unq_vals <- sort(dat[, unique(V1)])
# cumulative count for each unique value by row
cum_count <- as.data.table(lapply(unq_vals, function(x) cumsum(dat$V1==x)))
# running ranking for each unique value by row
cum_ranks <- t(apply(-cum_count, 1, rank, ties.method='max'))
现在,(例如)第8次观察的第2个唯一值的等级存储在:
cum_ranks[8, 2]
您可以按行获取每个项目的排名(并将其显示在可读表中)。如果对于第i行rank
&lt; = k,那么V1
的第i项是观察到的第k个最频繁项目之一。
dat[, .(V1, rank=sapply(1:length(V1), function(x) cum_ranks[x, V1[x]]))]
第一个代码块在我的机器上仅需0.6883929秒(根据粗略的now <- Sys.time(); [code block in here]; Sys.time() - now
时间),dat <- data.table(sample(1:50, 10000, replace=T))