我有一个DataFrame a
和系列b
。我希望找到a
到b
的每列的条件相关性,条件是b
的值。具体来说,我使用pd.cut
将b
分解为5个组。但是,我使用b
高于或低于均值的标准偏差而不是标准分位数。
np.random.seed(123)
a = (pd.DataFrame(np.random.randn(1000,3))
.add_prefix('col'))
b = pd.Series(np.random.randn(1000))
mu, sigma = b.mean(), b.std()
breakpoints = mu + np.array([-2., -1., 1., 2.]) * sigma
breakpoints = np.append(np.insert(breakpoints, 0, -np.inf), np.inf)
# There are now 6 breakpoints to create 5 groupings:
# array([ -inf, -1.91260048, -0.9230609 , 1.05601827, 2.04555785,
# inf])
labels = ['[-inf,-2]', '(-2,-1]', '(-1,1]', '(1,2]', '(2,inf]']
groups = pd.cut(b, bins=breakpoints, labels=labels)
这里一切都很好。我使用.corrwith
与.groupby
一起挂在最后一行,这会引发ValueError
:
a.groupby(groups).corrwith(b.groupby(groups))
有什么想法吗? a.corrwith(b)
的结果是一个系列,所以我认为这里的结果应该是一个以组/桶为列的DataFrame。例如,一列将是:
print(a[b < breakpoints[1]].corrwith(b[b < breakpoints[1]]))
# Correlation conditional on that `b` is [-inf, -2 stdev]
col0 0.43708
col1 -0.08440
col2 -0.02923
dtype: float64
答案 0 :(得分:0)
一种功能性但不漂亮的解决方案:
full = a.join(b.to_frame(name='_drop'))
corrs = (full.groupby(groups)
.corr()
.loc[(slice(None), a.columns), '_drop']
.unstack()
.T)
print(corrs)
[-inf,-2] (-2,-1] (-1,1] (1,2] (2,inf]
col0 0.43708 0.06716 0.02437 0.01695 0.05384
col1 -0.08440 0.04208 0.05529 -0.07146 0.14766
col2 -0.02923 -0.19672 0.01519 -0.02290 -0.17101