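I'm trying to get diagnostic plots for a linear regression in Python, and I'd like to know whether there is a quick way to do this.
In R, you can use the code snippet below, which gives you the Residuals vs Fitted, Normal Q-Q, Scale-Location, and Residuals vs Leverage plots.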
m1 <- lm(cost~ distance, data = df1)
summary(m1)
plot(m1)
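Is there a quick way to do this in Python?
There is a great blog post that describes how to get the same plots R gives you using Python code, but it requires quite a lot of code (at least compared with the R approach). Link: https://underthecurve.github.io/jekyll/update/2016/07/01/one-regression-six-ways.html#Python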
Answer 0 (score: 0)
I prefer to store everything in pandas and, where possible, plot with DataFrame.plot():
from matplotlib import pyplot as plt
from pandas import DataFrame
import scipy.stats as stats
import statsmodels.api as sm


def linear_regression(df: DataFrame) -> DataFrame:
    """Perform a univariate regression and store results in a new data frame.

    Args:
        df (DataFrame): original data set with columns x and y.

    Returns:
        DataFrame: a copy of the raw data plus the regression results.
    """
    # sm.OLS does not add an intercept by default, whereas R's
    # lm(cost ~ distance) does, so add a constant column explicitly.
    mod = sm.OLS(endog=df['y'], exog=sm.add_constant(df['x'])).fit()
    influence = mod.get_influence()

    res = df.copy()
    res['resid'] = mod.resid
    res['fittedvalues'] = mod.fittedvalues
    # Internally studentized residuals, i.e. the "standardized residuals"
    # that R's diagnostic plots show.
    res['resid_std'] = influence.resid_studentized_internal
    res['leverage'] = influence.hat_matrix_diag
    return res
def plot_diagnosis(df: DataFrame):
    """Draw the four diagnostic plots from the output of linear_regression."""
    # Set the style before creating the figure so that it takes effect.
    # Matplotlib >= 3.6 calls this style 'seaborn-v0_8'; older versions use 'seaborn'.
    plt.style.use('seaborn-v0_8')
    fig, axes = plt.subplots(nrows=2, ncols=2)

    # Residuals against fitted values.
    df.plot.scatter(
        x='fittedvalues', y='resid', ax=axes[0, 0]
    )
    axes[0, 0].axhline(y=0, color='grey', linestyle='dashed')
    axes[0, 0].set_xlabel('Fitted values')
    axes[0, 0].set_ylabel('Residuals')
    axes[0, 0].set_title('Residuals vs Fitted')

    # Normal Q-Q plot of the standardized residuals.
    sm.qqplot(
        df['resid_std'], dist=stats.norm, line='45', ax=axes[0, 1]
    )
    axes[0, 1].set_title('Normal Q-Q')

    # Scale-location plot: sqrt(|standardized residuals|) against fitted values.
    df.assign(resid_std_sqrt=df['resid_std'].abs() ** 0.5).plot.scatter(
        x='fittedvalues', y='resid_std_sqrt', ax=axes[1, 0]
    )
    axes[1, 0].set_xlabel('Fitted values')
    axes[1, 0].set_ylabel('Sqrt(|standardized residuals|)')
    axes[1, 0].set_title('Scale-Location')

    # Standardized residuals vs. leverage.
    df.plot.scatter(
        x='leverage', y='resid_std', ax=axes[1, 1]
    )
    axes[1, 1].axhline(y=0, color='grey', linestyle='dashed')
    axes[1, 1].set_xlabel('Leverage')
    axes[1, 1].set_ylabel('Standardized residuals')
    axes[1, 1].set_title('Residuals vs Leverage')

    plt.tight_layout()
    plt.show()
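For reference, here is a minimal pure-Python sketch of how leverage, standardized residuals, and Cook's distance can be computed by hand for a regression with one predictor and an intercept. The function name `simple_regression_diagnostics`, the dictionary keys, and the sample data are made up for illustration. It uses the textbook formulas (hat value 1/n + (x_i - mean(x))^2 / Sxx, standardized residual e_i / (s * sqrt(1 - h_i)), Cook's distance r_i^2 * h_i / (p * (1 - h_i))) and is not how statsmodels computes them internally.

```python
from math import sqrt


def simple_regression_diagnostics(x, y):
    """Per-observation diagnostics for y = a + b*x fitted by least squares."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    sxx = sum((xi - x_mean) ** 2 for xi in x)
    sxy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    slope = sxy / sxx
    intercept = y_mean - slope * x_mean

    fitted = [intercept + slope * xi for xi in x]
    resid = [yi - fi for yi, fi in zip(y, fitted)]
    p = 2  # estimated parameters: intercept and slope
    sigma = sqrt(sum(e ** 2 for e in resid) / (n - p))  # residual standard error

    rows = []
    for xi, fi, ei in zip(x, fitted, resid):
        h = 1.0 / n + (xi - x_mean) ** 2 / sxx  # leverage (hat value)
        r = ei / (sigma * sqrt(1.0 - h))        # standardized residual
        d = r ** 2 * h / (p * (1.0 - h))        # Cook's distance
        rows.append({
            "fitted": fi, "resid": ei,
            "leverage": h, "resid_std": r, "cooks_d": d,
        })
    return rows


# Hypothetical sample data: the last point lies far from the other x values.
rows = simple_regression_diagnostics([1, 2, 3, 4, 10], [1.1, 1.9, 3.2, 3.8, 12.0])
for row in rows:
    print({key: round(value, 3) for key, value in row.items()})
```

The leverages always sum to the number of fitted parameters (2 here), so a point whose leverage is close to 1, like the last one in the sample data, dominates the fit.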
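`plot_diagnosis` still lacks many features, but it is a good start. I learned how to extract the influence statistics from the question "Access standardized residuals, cook's values, hatvalues (leverage) etc. easily in Python?".
By the way, there is a package that has all of these features: dynobo/lmdiag.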