data.table slow aggregating on factor column

Time: 2016-09-04 19:44:35

Tags: r data.table

Ran into this issue today. I have a data.table with some categorical fields (i.e. factors). Something like this:

library(data.table)

set.seed(2016)
dt <- data.table(
  ID=factor(sample(30000, 2000000, replace=TRUE)),
  Letter=factor(LETTERS[sample(26, 2000000, replace=TRUE)])
)

dt
            ID Letter
      1:  5405      E
      2:  4289      E
      3: 25250      J
      4:  4008      J
      5: 14326      G
     ---

Now I want to calculate the gini impurity of each column of dt, grouped by the values in ID.

My attempt:

giniImpurity <- function(vals){
  # Returns the gini impurity of a set of categorical values
  # vals can either be the raw category instances (vals=c("red", "red", "blue", "green"))
  # or named category frequencies (vals=c(red=2, blue=1, green=1))
  # Gini impurity is the probability a value is incorrectly labeled when labeled
  # according to the distribution of classes in the set

  if(is(vals, "numeric")) counts <- vals else counts <- table(vals)
  total <- sum(counts)
  return(sum((counts/total)*(1-counts/total)))
}

# Calculate gini impurities
dt[, list(Samples=.N, ID.GiniImpurity=giniImpurity(ID), Letter.GiniImpurity=giniImpurity(Letter)), by=ID]
          ID Samples ID.GiniImpurity Letter.GiniImpurity
    1:  5405      66               0              0.9527
    2:  4289      73               0              0.9484
    3: 25250      60               0              0.9394
    4:  4008      66               0              0.9431
    5: 14326      79               0              0.9531
   ---

This works, but it is extremely slow. It seems that if I change ID from a factor to a numeric, it runs much faster. Is that what I should do in practice, or is there a cleaner way to speed up this operation?

Also, I know it is not necessary to calculate the gini impurity of ID grouped by itself, but please look past this. My real dataset has many more categorical features, which adds to the slowness.

Also note that I'm using data.table version 1.9.7 (devel).

Edit

Sorry guys... I just realized that when I tested this with ID as a numeric instead of a factor, the speedup came from how my call to giniImpurity() works: a numeric vals is treated as precomputed counts, so table() is never called for it. I guess the table() call is where the slowdown is. Still not 100% sure how to do this faster.
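One plausible reading of the slowdown, offered as a sketch rather than a measured diagnosis: when `vals` is a factor, `table(vals)` returns one cell for every level of the factor, not only for the values that are present. Each per-group call on ID therefore builds a table as long as the full set of ID levels. The example below shows this on a toy factor and introduces a hypothetical variant, `giniImpurityFast` (a name not in the original post), that counts only the values that occur. Zero-count levels contribute 0 to the sum, so both approaches give the same impurity.

```r
# Toy factor standing in for one group's values: 3 observations, 52 levels
vals <- factor(c("a", "a", "b"), levels = c(letters, LETTERS))

# table() keeps every level, including the 50 that never occur in vals
counts_all <- table(vals)
n_cells <- length(counts_all)

# Impurity computed the same way as in the question, over all 52 cells
p_all <- counts_all / sum(counts_all)
gini_table <- sum(p_all * (1 - p_all))

# Hypothetical variant (assumption, not the author's code):
# count only the distinct values that actually appear
giniImpurityFast <- function(vals) {
  counts <- tabulate(match(vals, unique(vals)))
  p <- counts / sum(counts)
  sum(p * (1 - p))
}

gini_fast <- giniImpurityFast(vals)
```

Here `n_cells` is 52 even though only two categories occur, while `gini_table` and `gini_fast` agree. Whether this accounts for all of the slowdown on the real data would need to be confirmed by timing.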

1 Answer:

Answer 0: (score: 1)

Got it.

giniImpurities <- function(dt){
  # Returns pairs of categorical fields (cat1, cat2, GI) where GI is the weighted gini impurity of 
  # cat2 relative to the groups determined by cat1

  #--------------------------------------------------
  # Subset dt by just the categorical fields

  catfields <- colnames(dt)[sapply(dt, is.factor)]
  cats1 <- dt[, catfields, with=FALSE]

  # Build a table to store the results
  varpairs <- CJ(Var1=catfields, Var2=catfields)
  varpairs[Var1==Var2, GI := 0]

  # Loop through each grouping variable
  for(catcol in catfields){
    print(paste("Calculating gini impurities by field:", catcol))

    setkeyv(cats1, catcol)
    impuritiesDT <- cats1[, list(Samples=.N), keyby=catcol]

    # Loop through each of the other categorical columns
    for(colname in setdiff(catfields, catcol)){

      # Get the gini impurity for each pair (catcol, other)
      counts <- cats1[, list(.N), by=c(catcol, colname)]
      impurities <- counts[, list(GI=sum((N/sum(N))*(1-N/sum(N)))), by=catcol]
      impuritiesDT[impurities, GI := GI]
      setnames(impuritiesDT, "GI", colname)
    }

    cats1.gini <- melt(impuritiesDT, id.vars=c(catcol, "Samples"))
    cats1.gini <- cats1.gini[, list(GI=weighted.mean(x=value, w=Samples)), by=variable]
    cats1.gini <- cats1.gini[, list(Var1=catcol, Var2=variable, GI)]
    varpairs[cats1.gini, `:=`(GI=i.GI), on=c("Var1", "Var2")]
  }

  return(varpairs[])
}

giniImpurities(dt)
      Var1    Var2        GI
1:  Letter  Letter 0.0000000
2:  Letter Letter2 0.9615258
3:  Letter  PGroup 0.9999537
4: Letter2  Letter 0.9615254
5: Letter2 Letter2 0.0000000
6: Letter2  PGroup 0.9999537
7:  PGroup  Letter 0.9471393
8:  PGroup Letter2 0.9470965
9:  PGroup  PGroup 0.0000000
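
The core idea in the function above seems to be: count each pair of values once with `.N`, derive each group's impurity from those counts, then average the group impurities weighted by group size. A minimal sketch of those three steps on toy data follows. The table `toy` and its columns `g` and `x` are made up for illustration and are not part of the original post.

```r
library(data.table)

# Hypothetical toy data: grouping column g, categorical column x
toy <- data.table(
  g = factor(c("a", "a", "a", "a", "b", "b")),
  x = factor(c("u", "u", "v", "v", "u", "u"))
)

# Step 1: count each (g, x) pair once instead of calling table() per group
counts <- toy[, list(N = .N), by = c("g", "x")]

# Step 2: impurity of x within each group of g, computed from the pair counts
per_group <- counts[, list(Samples = sum(N),
                           GI = sum((N / sum(N)) * (1 - N / sum(N)))),
                    by = "g"]

# Step 3: weight each group's impurity by the number of rows in the group
weighted_gi <- per_group[, weighted.mean(GI, w = Samples)]
```

Group `a` has two equally frequent values of `x`, so its impurity is 0.5. Group `b` has a single value, so its impurity is 0. The weighted result is (4 * 0.5 + 2 * 0) / 6 = 1/3.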