计算大熊猫滚动交叉口的大小

时间:2014-03-17 21:11:24

标签: python pandas

我有一个dataframe,其中包含群组标签(' B')和每个群组的元素(' A')。组标签是有序的,我想知道组i中有多少元素出现在组i + 1中。

一个例子:

df= pd.DataFrame({ 'A': ['a','b','c','a','c','a','d'], 'B' : [1,1,1,2,2,3,3]})

   A  B
0  a  1
1  b  1
2  c  1
3  a  2
4  c  2
5  a  3
6  d  3

所需的输出类似于:

B
1  NaN
2  2
3  1

解决这个问题的一种方法是计算组I和组i + 1的并集中的不同元素的数量,然后减去每个组中不同元素的数量。我试过了:

pd.rolling_apply(grp['A'], lambda x: len(x.unique()),2)

但这会产生错误:

AttributeError: 'Series' object has no attribute 'type'

如何让它与rolling_apply一起使用,还是有更好的方法来解决这个问题?

1 个答案:

答案 0 :(得分:1)

使用集合并移动结果的方法:

首先对数据帧进行分组,然后将每个组的A列转换为集合:

In [86]: grp = df.groupby('B')
In [87]: s = grp.apply(lambda x : set(x['A']))
In [88]: s
Out[88]: 
B
1    set([a, c, b])
2       set([a, c])
3       set([a, d])
dtype: object

要计算连续集之间的交集,请进行移位版本(我将NaN替换为空集以进行下一步):

In [89]: s2 = s.shift(1).fillna(set([]))
In [90]: s2
Out[90]: 
B
1           set([])
2    set([a, c, b])
3       set([a, c])
dtype: object

合并两个系列并计算交叉点的长度:

In [91]: s.combine(s2, lambda x, y: len(x.intersection(y)))
Out[91]: 
B
1    0
2    2
3    1
dtype: object

执行最后一步的另一种方法(对于集&意味着intersection):

df = pd.concat([s, s2], axis=1)
df.apply(lambda x: len(x[0] & x[1]), axis=1)

滚动应用不起作用的原因是1)你提供了一个GroupBy对象而不是一个系列,2)它只适用于数值。

相关问题