Summing the matrices produced by each group of a data.table

Time: 2017-02-07 04:47:10

Tags: r data.table

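For each group of a data.table, I create a matrix. I want to add all of these matrices together into a single result. The following example illustrates this better: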

set.seed(1)
library(data.table)
A <- data.table(A = letters[1:3], B = rnorm(3))
fun <- function(dt){ matrix(rnorm(9),nrow = 3, ncol = 3)  }
A[,fun(.SD), by = A]

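This outputs a single column in which the entries of all the matrices are stacked. I would like to recover the matrix form, or use another approach.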
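If you want to keep the stacked output, one way to fold it back into a summed matrix is sketched below. It assumes, as described above, that the grouped call stores each group's values in a column named V1, that every group returns a 3 x 3 matrix, and that each group's 9 values appear in column-major order (the order in which R stores a matrix). The names res, per_group and total are only illustrative.

```r
library(data.table)

set.seed(1)
A <- data.table(A = letters[1:3], B = rnorm(3))
fun <- function(dt) matrix(rnorm(9), nrow = 3, ncol = 3)

# Grouped call: each group's 3 x 3 matrix is flattened into 9 rows of V1
res <- A[, fun(.SD), by = A]

# Put each group's 9 values into its own column (one column per group),
# add the columns row by row, then fold the 9 sums back into a 3 x 3 matrix
per_group <- matrix(res$V1, nrow = 9)
total <- matrix(rowSums(per_group), nrow = 3, ncol = 3)
total
```

This only works because every group returns a matrix of the same shape; if the shapes could differ, keeping the matrices in a list (as in the answers) is safer.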

I want to add up all the matrices I get with by (so I don't really mind whether I use by or data.table, whichever gives me the answer):

by(A, A$A, fun)

A$A: a
           [,1]       [,2]      [,3]
[1,] -0.6264538  1.5952808 0.4874291
[2,]  0.1836433  0.3295078 0.7383247
[3,] -0.8356286 -0.8204684 0.5757814
-------------------------------------------------------------------------------------- 
A$A: b
           [,1]       [,2]        [,3]
[1,] -0.3053884 -0.6212406 -0.04493361
[2,]  1.5117812 -2.2146999 -0.01619026
[3,]  0.3898432  1.1249309  0.94383621
-------------------------------------------------------------------------------------- 
A$A: c
          [,1]        [,2]        [,3]
[1,] 0.8212212  0.78213630  0.61982575
[2,] 0.5939013  0.07456498 -0.05612874
[3,] 0.9189774 -1.98935170 -0.15579551

3 answers:

Answer 0 (score: 2)

The result of by is just a list of matrices, so you only need to add up all the elements of the list (see How to sum a numeric list elements in R):

> set.seed(1)
> A <- data.table(A = letters[1:3], B = rnorm(3))
> fun <- function(dt){ matrix(rnorm(9),nrow = 3, ncol = 3)  }
> l <- by(A, A$A, fun)
> Reduce("+",l)
          [,1]      [,2]       [,3]
[1,]  1.756177 1.0623212 -0.9549196
[2,] -1.810627 0.6660057  1.6275324
[3,] -1.684889 1.3638221  1.7267622
> l[[1]] + l[[2]] + l[[3]]
          [,1]      [,2]       [,3]
[1,]  1.756177 1.0623212 -0.9549196
[2,] -1.810627 0.6660057  1.6275324
[3,] -1.684889 1.3638221  1.7267622
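An alternative to Reduce(), sketched here in base R with a stand-in list l of three random 3 x 3 matrices (a plain list like the one by() returns), is to stack the matrices into a 3-D array and sum over the third dimension with rowSums(..., dims = 2). This computes the same element-wise sum in a single vectorised call.

```r
set.seed(1)
# Stand-in for the list returned by by(): three 3 x 3 matrices
l <- lapply(1:3, function(i) matrix(rnorm(9), nrow = 3, ncol = 3))

# Element-wise sum by folding `+` over the list (as in the answer)
via_reduce <- Reduce(`+`, l)

# Same sum: stack into a 3 x 3 x length(l) array, then sum over the 3rd dimension
arr <- array(unlist(l), dim = c(3, 3, length(l)))
via_array <- rowSums(arr, dims = 2)

all.equal(via_reduce, via_array)
```

Reduce() is the more general choice (it works for any objects that support +), while the array route assumes all matrices share the same dimensions.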

Answer 1 (score: 2)

If you want to stay within data.table, you can nest a list inside .() in j:

A[,.(mats=.(fun(.SD))), by = A][, Reduce(`+`, mats)]
#          [,1]      [,2]       [,3]
#[1,]  1.756177 1.0623212 -0.9549196
#[2,] -1.810627 0.6660057  1.6275324
#[3,] -1.684889 1.3638221  1.7267622
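To see what the nested .() does, the intermediate table can be inspected before reducing. The sketch below rebuilds the question's A and fun; mats is the column name from the answer, while res and total are illustrative names. The inner .() (an alias for list() inside j) wraps each group's matrix in a length-one list, so every group contributes a single cell of a list column instead of nine stacked rows.

```r
library(data.table)

set.seed(1)
A <- data.table(A = letters[1:3], B = rnorm(3))
fun <- function(dt) matrix(rnorm(9), nrow = 3, ncol = 3)

# One row per group; `mats` is a list column holding one 3 x 3 matrix per cell
res <- A[, .(mats = .(fun(.SD))), by = A]
nrow(res)            # 3: one row per group
dim(res$mats[[1]])   # 3 3: each matrix keeps its shape

# Fold `+` over the list column to get the element-wise sum
total <- Reduce(`+`, res$mats)
```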

Answer 2 (score: 1)

Here is a tidyverse approach:

library(tidyverse)
A %>%
     split(.$A) %>% 
     map(fun) %>%
     reduce(`+`)
#        [,1]       [,2]       [,3]
#[1,] -2.9614395  3.9622030  2.0717522
#[2,]  2.9163822  0.7773892  0.3030291
#[3,]  0.4163046 -1.1495702 -1.4320626

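Note: these values differ from those in the other answers only because fun draws fresh random numbers, so the result depends on the random seed in effect when it runs. With the same seed set beforehand, both approaches give identical results, e.g.: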

set.seed(24)
Reduce(`+`, by(A, A$A, fun))
#            [,1]       [,2]       [,3]
#[1,]  0.06642572 -1.9509985 -0.6730669
#[2,] -0.26398712 -2.2912755 -0.8955920
#[3,]  0.94358370  0.3295733  0.6023412

set.seed(24)
A %>%
    split(.$A) %>%
    map(fun) %>% 
    reduce(`+`)
#          [,1]       [,2]       [,3]
#[1,]  0.06642572 -1.9509985 -0.6730669
#[2,] -0.26398712 -2.2912755 -0.8955920
#[3,]  0.94358370  0.3295733  0.6023412
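Without the tidyverse, the same split, apply, reduce pipeline can be written in base R with split(), lapply() and Reduce(). This sketch uses a plain data.frame as a stand-in for the data.table so that it runs without extra packages. With the seed reset before each run, it should match the by() + Reduce() result, because both call fun for the groups in the same order (a, b, c) and therefore draw the same random numbers.

```r
set.seed(24)
# Plain data.frame stand-in for the question's data.table
A <- data.frame(A = letters[1:3], B = rnorm(3))
fun <- function(dt) matrix(rnorm(9), nrow = 3, ncol = 3)

# split() -> one data frame per group; lapply() -> one matrix per group;
# Reduce() -> element-wise sum of all the matrices
total_base <- Reduce(`+`, lapply(split(A, A$A), fun))

# Reset the RNG, rebuild A, and redo the computation with by()
set.seed(24)
A <- data.frame(A = letters[1:3], B = rnorm(3))
total_by <- Reduce(`+`, by(A, A$A, fun))

all.equal(total_base, total_by)
```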