通过聚合和过滤进行高效的交叉连接

时间:2017-08-08 00:06:40

标签: r data.table dplyr sqldf cross-join

根据标题,我希望与表进行交叉连接,该表执行聚合功能并过滤表中的几个变量。

我有以下类似的数据:

library(dplyr)
library(data.table)
library(sqldf)

sales <-  data.frame(salesx = c(3000, 2250,850,1800,1700,560,58,200,965,1525)
                     ,week = seq(from = 1, to = 10, by = 1)
                     ,uplift = c(0.04)
                     ,slope = c(100)
                     ,carryover = c(.35))
spend <- data.frame(spend = seq(from = 1, to = 50000, by = 1))

tempdata <- merge(spend,sales,all=TRUE)
tempdata$singledata <- as.numeric(1) 

以下是我通过基于sql的解决方案尝试完成的示例:

newdata <- sqldf("select a.spend, a.week,
                 sum(case when b.week > a.week
                 then b.salesx*(b.uplift*(1-exp(-(power(b.singledata,b.week-a.week)/b.slope))))/b.spend
                 else 0.0 end) as calc3
                 from tempdata a, tempdata b  
                 where a.spend = b.spend 
                 group by a.spend,a.week")

这提供了我想要的结果,但它有点慢,特别是我的真实数据集大约有100万条记录。对a)如何加速sqldf函数提出一些建议会很棒;或者b)使用更有效的data.table / dplyr方法(我无法理解交叉连接/聚合/过滤三元组问题)。

下面非等同加入解决方案的清晰度:

我对非equi连接解决方​​案有几个问题 - 输出很好而且非常快。在了解代码是如何工作的时候,我将其分解为:

breakdown <- setDT(tempdata)[tempdata, .(spend, uplift, slope,carryover,salesx,  singledata, week, i.week,x.week, i.salesx,x.salesx, x.spend, i.spend), on=.(spend, week > week)]

根据细目,为了与原始计算保持一致,它应该是:

x.salesx*(uplift*(1.0-exp(-(`^`(singledata,x.week-week)/slope))))/i.spend

这一点不明显的原因是因为我使用的例子中的“幂”部分并没有真正做任何事情(总是1)。使用的实际计算是(向数据添加转移变量):

SQL

b.salesx*(b.uplift*(1-exp(-(power((b.singledata*b.carryover),b.week-a.week)/b.slope))))/b.spend (sql)

我的data.table解决方案

sum(salesx.y*(uplift.y*(1-exp(-((singledata.y*adstock.y)^(week.y-week.x)/slope.y))))/spend), by=list(spend, week.x)

但是,在添加'carryover'变量时,我无法使用非equi连接解决方​​案。

x.salesx*(uplift*(1.0-exp(-(`^`((singledata*carryover),x.week-week)/slope))))/i.spend

2 个答案:

答案 0 :(得分:3)

引入版本1.9.8(在2016年11月25日CRAN上)data.table 非equi连接,这有助于避免消耗内存的交叉连接:

library(data.table)
newdata4 <- 
  # coerce to data.table
  setDT(tempdata)[
    # non-equi self-join
    tempdata, on = .(spend, week > week), 
    # compute result
    .(calc3 = sum(salesx*(uplift*(1.0-exp(-(`^`(singledata,week-i.week)/slope))))/i.spend)), 
    # grouped by join parameters
    by = .EACHI][
      # replace NA
      is.na(calc3), calc3 := 0.0][]

# check that results are equal
all.equal(newdata, as.data.frame(newdata4[order(spend, week)]))
[1] TRUE

基准

OP使用交叉联接提供了three different个解决方案,两个sqldf变体和一个data.table方法。这些与非等连接进行比较。

以下代码

dt_tempdata <- data.table(tempdata)
microbenchmark::microbenchmark(
  sqldf = {
    newdata <- sqldf("select a.spend, a.week,
                 sum(case when b.week > a.week
                     then b.salesx*(b.uplift*(1-exp(-(power(b.singledata,b.week-a.week)/b.slope))))/b.spend
                     else 0.0 end) as calc3
                     from tempdata a, tempdata b  
                     where a.spend = b.spend 
                     group by a.spend,a.week")
  },
  sqldf_idx = {
    newdata2 <- sqldf(c('create index newindex on tempdata(spend)',
                        'select a.spend, a.week,
                        sum(case when b.week > a.week
                        then b.salesx*(b.uplift*(1-exp(-(power(b.singledata,b.week-a.week)/b.slope))))/b.spend
                        else 0.0 end) as calc3
                        from main.tempdata a left join main.tempdata b  
                        on a.spend = b.spend 
                        group by a.spend,a.week'), dbname = tempfile())
  },
  dt_merge = { 
    newdata3 <- merge(dt_tempdata, dt_tempdata, by="spend", all=TRUE, allow.cartesian=TRUE)[
      week.y > week.x, 
      .(calc3 = sum(salesx.y*(uplift.y*(1-exp(-(singledata.y^(week.y-week.x)/slope.y)))))), 
      by=.(spend, week.x)]
  },
  dt_nonequi = {
    newdata4 <- dt_tempdata[
      dt_tempdata, on = .(spend, week > week), 
      .(calc3 = sum(salesx*(uplift*(1.0-exp(-(`^`(singledata,week-i.week)/slope))))/i.spend)), 
      by = .EACHI][is.na(calc3), calc3 := 0.0]
  },
  times = 3L
)

返回这些时间:

Unit: seconds
       expr       min        lq      mean    median        uq       max neval cld
      sqldf  9.456110 10.081704 10.647193 10.707299 11.242735 11.778171     3   b
  sqldf_idx 10.980590 11.477774 11.734239 11.974958 12.111064 12.247170     3   b
   dt_merge  3.037857  3.147274  3.192227  3.256692  3.269412  3.282131     3  a 
 dt_nonequi  1.768764  1.776581  1.792359  1.784397  1.804156  1.823916     3  a

对于给定的问题大小,非equi连接速度最快,几乎是合并/交叉连接data.table方法的两倍,比sqldf代码快6倍。有趣的是,索引创建和/或临时文件使用在我的系统上似乎相当昂贵。

请注意,我已经简化了OP data.table解决方案。

最后,除合并/交叉连接(我已经避免修复此版本)之外的所有版本都返回相同的结果。

all.equal(newdata, newdata2) # TRUE
all.equal(newdata, as.data.frame(newdata3[order(spend, week.x)])) # FALSE (last week missing)
all.equal(newdata, as.data.frame(newdata4[order(spend, week)])) # TRUE

问题规模较大

OP报告说,合并/交叉连接data.table解决方案的内存耗尽了1M行的生产数据集。为了验证非equi连接方法消耗更少的内存,我测试了它的问题大小为5 M行(nrow(tempdata)),这比以前的基准测试运行大十倍。在我的具有8 GB内存的PC上,运行在大约18秒内完成没有问题。

Unit: seconds
       expr      min       lq     mean   median       uq      max neval
 dt_nonequi 18.12387 18.12657 18.23454 18.12927 18.28987 18.45047     3

答案 1 :(得分:1)

终于有时间再次调查一下:

我原来的解决方案:

  fields = ('id', 'title', 'description')

使用索引(虽然有些东西告诉我这不能正常工作):

  system.time(newdata <- sqldf("select a.spend, a.week,
                   sum(case when b.week > a.week
                   then b.salesx*(b.uplift*(1-exp(-(power(b.singledata,b.week-a.week)/b.slope))))/b.spend
                   else 0.0 end) as calc3
                   from tempdata a, tempdata b  
                   where a.spend = b.spend 
                   group by a.spend,a.week"))

   user  system elapsed 
  11.99    3.77   16.11 

Data.table解决方案(不会从sql中的ifelse语句返回0):

system.time(newdata2 <- sqldf(c('create index newindex on tempdata(spend)',
                                    'select a.spend, a.week,
                                    sum(case when b.week > a.week
                                    then b.salesx*(b.uplift*(1-exp(-(power(b.singledata,b.week-a.week)/b.slope))))/b.spend
                                    else 0.0 end) as calc3
                                    from main.tempdata a left join main.tempdata b  
                                    on a.spend = b.spend 
                                    group by a.spend,a.week'), dbname = tempfile()))

   user  system elapsed 
  12.73    2.93   15.76 

基于sql的解决方案的优点在于,因为临时输出存储在sql server而不是内存中,所以我不会遇到麻烦的'无法分配矢量'问题,这会发生在data.table / dplyr解决方案(当我添加更多数据时)...缺点是运行需要更长的时间。