找出数据集中哪些要素共线

时间:2019-02-05 20:22:01

标签: python python-3.x machine-learning statistics statsmodels

我建立了一个基于多个特征来预测房屋价格的模型。

import statsmodels.api as statsmdl
from sklearn import datasets

X = data[['NumberofRooms', 'YearBuilt','Type','NewConstruction']
y = data["Price"]

model = statsmdl.OLS(y, X).fit()
predictions = model.predict(X)
model.summary()

如何找出这些特征中的哪些是共线的?

1 个答案:

答案 0 :(得分:3)

您可以使用DataFrame.corr()方法。

演示:

In [27]: df = pd.DataFrame(np.random.randint(10, size=(5,3)), columns=list('abc'))

In [28]: df['d'] = df['a'] * 10 - df['b'] / np.pi

In [29]: df['e'] = np.log(df['c'] **2)

In [30]: c = df.corr()

In [31]: c
Out[31]:
          a         b         c         d         e
a  1.000000  0.734858  0.113787  0.999837  0.067358
b  0.734858  1.000000 -0.523635  0.722485 -0.598739
c  0.113787 -0.523635  1.000000  0.129945  0.984257
d  0.999837  0.722485  0.129945  1.000000  0.084615
e  0.067358 -0.598739  0.984257  0.084615  1.000000

In [32]: c[c >= 0.7]
Out[32]:
          a         b         c         d         e
a  1.000000  0.734858       NaN  0.999837       NaN
b  0.734858  1.000000       NaN  0.722485       NaN
c       NaN       NaN  1.000000       NaN  0.984257
d  0.999837  0.722485       NaN  1.000000       NaN
e       NaN       NaN  0.984257       NaN  1.000000

In [33]: c[c >= 0.7].stack().reset_index(name='cor').query("abs(cor) < 1.0")
Out[33]:
   level_0 level_1       cor
1        a       b  0.734858
2        a       d  0.999837
3        b       a  0.734858
5        b       d  0.722485
7        c       e  0.984257
8        d       a  0.999837
9        d       b  0.722485
11       e       c  0.984257