Conditional cumulative mean for each group in R

Posted: 2015-09-14 14:43:56

Tags: r dplyr mean

I have a dataset that looks like this:

id   a   b
1    AA  2
1    AB  5
1    AA  1
2    AB  2
2    AB  4
3    AB  4
3    AB  3
3    AA  1
3    AA  4

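For every record in each group (id), I need the cumulative mean of b, excluding the rows where a == 'AA' (those rows just carry the current mean forward). The sample output should be: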

id   a   b  mean
1    AA  2   -
1    AB  5   5
1    AA  1   5
2    AB  2   2
2    AB  4   (4+2)/2
3    AB  4   4
3    AB  3   (4+3)/2
3    AA  1   (4+3)/2
3    AA  4   (4+3)/2
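To pin down the rule the expected output encodes, here is a plain loop written for illustration only (cond_cummean_loop is a hypothetical helper, not something from the question). Within each id it keeps a running sum and count of b over the rows where a != 'AA'; an 'AA' row simply repeats the current mean, or gets NA when no counted row has been seen yet. The input is the nine rows of the expected output.

```r
# Hypothetical helper: spells out the rule behind the expected output.
cond_cummean_loop <- function(id, a, b) {
  out <- numeric(length(b))
  for (g in unique(id)) {
    s <- 0  # running sum of b over non-'AA' rows in this group
    n <- 0  # running count of non-'AA' rows in this group
    for (i in which(id == g)) {
      if (a[i] != "AA") {
        s <- s + b[i]
        n <- n + 1
      }
      # 'AA' rows repeat the current mean; NA before the first counted row
      out[i] <- if (n > 0) s / n else NA
    }
  }
  out
}

id <- c(1, 1, 1, 2, 2, 3, 3, 3, 3)
a  <- c("AA", "AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA")
b  <- c(2, 5, 1, 2, 4, 4, 3, 1, 4)
res <- cond_cummean_loop(id, a, b)
res  # NA 5.0 5.0 2.0 3.0 4.0 3.5 3.5 3.5
```

The loop is only meant as a readable reference for checking the vectorised answers below, not as the recommended approach.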

I tried to do it with dplyr and cummean, but I get an error:

df <- df %>%
       group_by(id) %>%
       mutate(mean = cummean(b[a != 'AA']))
  

Error: incompatible size (123), expecting 147 (the group size) or 1
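The error comes from the subsetting. b[a != 'AA'] is shorter than the group, so cummean() returns fewer values than there are rows and mutate() cannot store them as a column (in the asker's data, presumably 123 non-'AA' rows in a group of 147). A base-R sketch on group id == 3 shows the mismatch; cumsum(x) / seq_along(x) stands in for dplyr::cummean() so the example needs no packages.

```r
# Group id == 3 from the question: four rows, two of them 'AA'.
a <- c("AB", "AB", "AA", "AA")
b <- c(4, 3, 1, 4)

kept <- b[a != "AA"]                           # subsetting drops the 'AA' rows
length(b)                                      # 4: the group size mutate() expects
length(kept)                                   # 2: what cummean() actually receives
running <- cumsum(kept) / seq_along(kept)      # cumulative mean of the kept values
running                                        # 4.0 3.5: only two values for four rows
```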

Can you suggest a better way to do this in R?

2 Answers:

Answer 0 (score: 3)

The trick here is to rebuild the cummean by hand: divide a cumsum that ignores the 'AA' rows by a running count of the non-'AA' rows. As a one-liner:

library(dplyr)
df %>% group_by(id) %>% mutate(mean = cumsum(b * (a != 'AA')) / cumsum(a != 'AA'))

We can make this a little nicer by pulling a != 'AA' out into its own column (to my mind, the "multiply by a != 'AA', magic!" part is the ugly bit):

df %>%
    group_by(id) %>%
    mutate(relevance = 0 + (a != 'AA'),   # 1 for rows that count, 0 for 'AA' rows
           mean = cumsum(relevance * b) / cumsum(relevance))   # 0/0 gives NaN before the first counted row
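For comparison, the same cumsum-ratio idea can be sketched in base R with ave() instead of dplyr; this is a swapped-in technique, not part of the answer. The weight w plays the role of the relevance column, and the data frame is the nine-row dataset from the question.

```r
df <- data.frame(
  id = c(1L, 1L, 1L, 2L, 2L, 3L, 3L, 3L, 3L),
  a  = c("AA", "AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA"),
  b  = c(2L, 5L, 1L, 2L, 4L, 4L, 3L, 1L, 4L),
  stringsAsFactors = FALSE
)

w   <- as.numeric(df$a != "AA")              # 1 for rows that count, 0 for 'AA' rows
num <- ave(df$b * w, df$id, FUN = cumsum)    # cumulative sum of counted b, per id
den <- ave(w, df$id, FUN = cumsum)           # running count of counted rows, per id
df$mean <- num / den                         # 0/0 yields NaN before the first counted row
df$mean  # NaN 5.0 5.0 2.0 3.0 4.0 3.5 3.5 3.5
```

Like the dplyr version, the first row of id 1 comes out as NaN rather than NA, because both the sum and the count are still 0 there.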

Answer 1 (score: 2)

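There may be a simpler way. Here we group by 'id', first turn the elements of 'b' that correspond to 'AA' in 'a' into NA (b*NA^(a=='AA')), and then create the new column 'Mean'. NA^(a=='AA') returns NA for the 'AA' entries in 'a' and 1 for all other values, so when we multiply by 'b', the 1s become the values of 'b' while the NAs stay NA. We use na.aggregate to replace each NA with the mean of the non-NA elements in its group, and then wrap that in cummean to get the cumulative mean. If the first value of 'a' in a group is 'AA', we get NA for that row by multiplying by NA^(row_number()==1 & a=='AA').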
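The NA^ masking step can be checked on its own in base R. It relies on R defining x^0 as 1 even when x is NA, so NA^FALSE is 1 while NA^TRUE stays NA. The input is group id == 3 from the question; na.aggregate() from zoo is deliberately left out here.

```r
a <- c("AB", "AB", "AA", "AA")
b <- c(4, 3, 1, 4)

mask   <- NA^(a == "AA")  # NA^TRUE is NA; NA^FALSE is NA^0, which R defines as 1
masked <- b * mask        # 'AA' values of b become NA, the rest are kept
mask                      # 1 1 NA NA
masked                    # 4 3 NA NA
```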

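library(zoo)    # for na.aggregate()
library(dplyr)
df %>% 
   group_by(id) %>% 
   # blank out 'AA' values, fill them with the group mean, take the running mean,
   # then force NA on the first row of a group when that row is 'AA'
   mutate(Mean= cummean(na.aggregate(b*NA^(a=='AA')))*
                 NA^(row_number()==1 & a=='AA'))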
# Source: local data frame [9 x 4]
#Groups: id [3]

#      id     a     b  Mean
#   (int) (chr) (int) (dbl)
#1     1    AA     2    NA
#2     1    AB     5   5.0
#3     1    AA     1   5.0
#4     2    AB     2   2.0
#5     2    AB     4   3.0
#6     3    AB     4   4.0
#7     3    AB     3   3.5
#8     3    AA     1   3.5
#9     3    AA     4   3.5

Data

df <- structure(list(id = c(1L, 1L, 1L, 2L, 2L, 3L, 3L, 3L, 3L),
                     a = c("AA", "AB", "AA", "AB", "AB", "AB", "AB", "AA", "AA"),
                     b = c(2L, 5L, 1L, 2L, 4L, 4L, 3L, 1L, 4L)),
                .Names = c("id", "a", "b"),
                class = "data.frame", row.names = c(NA, -9L))
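One caveat worth checking before reusing this answer on other data: na.aggregate() fills each NA with the mean of all non-NA values in the group, including rows further down, so a filled 'AA' row can pull later values into the running mean. On the question's data the output matches the expected table, but on a hypothetical group where an 'AA' row sits between two counted rows with different values, the result differs from the cumsum-ratio approach. The sketch below mimics na.aggregate() with its default FUN = mean using replace(), so it runs without zoo; that substitution is our assumption about na.aggregate's default behaviour.

```r
# Hypothetical group: a counted row, an 'AA' row, then another counted row.
a <- c("AB", "AA", "AB")
b <- c(2, 1, 10)

masked <- b * NA^(a == "AA")                                          # 2 NA 10
# Assumed stand-in for zoo::na.aggregate(masked): fill NA with the mean of non-NA values
filled <- replace(masked, is.na(masked), mean(masked, na.rm = TRUE))  # 2 6 10
via_fill <- cumsum(filled) / seq_along(filled)                        # cummean of filled

w <- as.numeric(a != "AA")
via_ratio <- cumsum(b * w) / cumsum(w)                                # conditional cummean

via_fill   # 2 4 6: row 2 already "sees" the 10 from row 3
via_ratio  # 2 2 6: row 2 repeats the mean of the rows so far
```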