清晰的方法来计算R中data.table的两列之间的转换概率

时间:2016-03-11 14:59:43

标签: r data.table probability

玩具示例:

library(data.table)

set.seed(1)
n_people <- 100
groups <- c("A", "B", "C")
example_table <- data.table(person_id=seq_len(n_people),
                            group_2010=sample(groups, n_people, TRUE),
                            group_2011=sample(groups, n_people, TRUE))

## Error-prone and requires lots of typing -- programmatic alternative?
transition_probs <- example_table[, list(pr_A_2011=mean(group_2011=="A"),
                                         pr_B_2011=mean(group_2011=="B"),
                                         pr_C_2011=mean(group_2011=="C")),
                                         by=group_2010]
transition_probs  # Essentially a transition matrix giving Pr[group_2011 | group_2010]

#    group_2010 pr_A_2011 pr_B_2011 pr_C_2011
# 1:          A 0.1481481 0.5185185 0.3333333
# 2:          B 0.3684211 0.3947368 0.2368421
# 3:          C 0.3142857 0.3142857 0.3714286

当组是A,B,C时,上面的“手动”方法很好,但是如果有更多组(或者我们只有groups向量但是提前不知道,则会变得混乱)它包含什么)。

在上面的示例代码中是否有“data.table方法”来计算transition_probs对象?列表(pr_A_2011 = ...)可以用程序化的东西替换吗?

我担心的是,如果我添加了一个组D,我将不得不在多个位置编辑代码,特别是键入pr_D_2011=mean(group_2011=="D")

3 个答案:

答案 0 :(得分:3)

我愿意

lvls = example_table[, sort(unique(c(group_2010, group_2011))) ]
x = dcast(example_table, group_2010~group_2011)[, N := Reduce(`+`,.SD), .SDcols=lvls]

#    group_2010  A  B  C  N
# 1:          A  6  9 15 30
# 2:          B 15  4 12 31
# 3:          C 11 11 17 39

从这里开始,如果你想要转换概率,只需除以N

x[, (lvls) := lapply(.SD,`/`, x$N), .SDcols=lvls]
# or, with data.table 1.9.7+
x[, (lvls) := lapply(.SD,`/`, N), .SDcols=lvls]

#    group_2010         A         B         C  N
# 1:          A 0.1481481 0.5185185 0.3333333 27
# 2:          B 0.3684211 0.3947368 0.2368421 38
# 3:          C 0.3142857 0.3142857 0.3714286 35

答案 1 :(得分:2)

data.table的设计有意与data.frames上的操作兼容,因此,除非您能(a)证明此操作是一个巨大的瓶颈,并且(b)证明替代解决方案是显着的更快,为什么不坚持简洁明了:

prop.table(table(example_table[,2:3,with=FALSE]),1)
          group_2011
group_2010         A         B         C
         A 0.1481481 0.5185185 0.3333333
         B 0.3684211 0.3947368 0.2368421
         C 0.3142857 0.3142857 0.3714286

答案 2 :(得分:1)

我认为目前的答案都能很好地解决你的问题。我会回答然后以更通用的方式处理它 如果您想要真正的编程能力,可以使用computing on the language R语言功能。

  

R属于一类编程语言,其中子程序能够修改或构造其他子程序,并将结果作为语言本身不可或缺的一部分进行评估。

library(data.table)
set.seed(1)
n_people <- 100
groups <- c("A", "B", "C")
example_table <- data.table(person_id=seq_len(n_people),
                            group_2010=sample(groups, n_people, TRUE),
                            group_2011=sample(groups, n_people, TRUE))
f = function(data, groups, years) {
    stopifnot(is.data.table(data), length(groups) > 0L, length(years) == 2L, paste0("group_", years) %in% names(data))
    j.names = sprintf("pr_%s_%s", c(groups), years[2L])
    j.vals = lapply(setNames(groups, j.names), function(group) call("mean", call("==", as.name(sprintf("group_%s", years[2L])), group)))
    jj = as.call(c(list(as.name(".")), j.vals))
    data[, eval(jj), by = c(sprintf("group_%s", years[1L]))]
}
f(example_table, groups, 2010:2011)
#   group_2010 pr_A_2011 pr_B_2011 pr_C_2011
#1:          A 0.1481481 0.5185185 0.3333333
#2:          B 0.3684211 0.3947368 0.2368421
#3:          C 0.3142857 0.3142857 0.3714286

无需在少数地方替换代码,只需将参数传递给函数。