如何使用`lmplot`绘制线性回归而不拦截?

时间:2016-01-11 15:57:44

标签: python linear-regression seaborn

seaborn中的lmplot符合回归模型并具有拦截。但是,有时我想拟合回归模型而不拦截,即通过原点进行回归。

例如:

In [1]: import numpy as np
   ...: import pandas as pd
   ...: import seaborn as sns
   ...: import matplotlib.pyplot as plt
   ...: import statsmodels.formula.api as sfa
   ...: 

In [2]: %matplotlib inline
In [3]: np.random.seed(2016)
In [4]: x = np.linspace(0, 10, 32)
In [5]: y = 0.3 * x + np.random.randn(len(x))
In [6]: df = pd.DataFrame({'x': x, 'y': y})
In [7]: r = sfa.ols('y ~ x + 0', data=df).fit()
In [8]: sns.lmplot(x='x', y='y', data=df, fit_reg=True)
Out[8]: <seaborn.axisgrid.FacetGrid at 0xac88a20>

enter image description here

这个数字我想要的是什么:

In [9]: fig, ax = plt.subplots(figsize=(5, 5))
   ...: ax.scatter(x=x, y=y)
   ...: ax.plot(x, r.fittedvalues)
   ...: 
Out[9]: [<matplotlib.lines.Line2D at 0x5675a20>]

enter image description here

3 个答案:

答案 0 :(得分:1)

seaborn API不允许直接更改线性回归模型。

呼叫链为:

  • 在某个时间点_RegressionPlotter.plot()被调用以产生情节
  • 调用_RegressionPlotter.lineplot()进行拟合图
  • 本身称为regression模块中的fit_regression
  • 在您的情况下,这反过来又调用了许多深层次的回归方法,例如self.fit_fast(grid)

要使用其他回归模型,您可以:

  • 猴子修补_RegressionPlotter类并更改lineplot()行为
  • 猴子修补回归模块中的fit_regression()fit_fast()方法

要制作这样的海洋猴补丁,可以参考an answer I made a while ago that does the same type of hack。这不好,圣诞老人可能不开心。 这意味着您可以按预期目的动态修改seaborn。

当然,您需要实现自己的回归模型,才能在y = a * x而不是y = (a * x) + b中使用定律。 importanceofbeingernest已在评论this SO question中指出了此事。


一种不错的解决方法是建立自己的情节,但是您已经在自己的问题中回答了这一部分。

引用您自己的问题(我没有检查提供的代码):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as sfa

np.random.seed(2016)
x = np.linspace(0, 10, 32)
y = 0.3 * x + np.random.randn(len(x))
df = pd.DataFrame({'x': x, 'y': y})
r = sfa.ols('y ~ x + 0', data=df).fit()
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(x=x, y=y)
ax.plot(x, r.fittedvalues)

答案 1 :(得分:0)

如果仅出于显示目的,您可以通过调整数据的平均值来绕过和更改y-ticks。

您可以这样做:

sns.lmplot(x='x', y='y', data=df, fit_reg=True)
y_ticks = [int(round(ytick - np.mean(y), 0)) for ytick in plt.gca().get_yticks()]
plt.gca().set_yticklabels(y_ticks)

请注意,这不会改变生产线本身,也不会改变任何内部结构,只会改变现成的Vizualization。

答案 2 :(得分:-1)

这是否适合您的目的?

sns.lmplot(x='x', y='y', data=df, fit_reg=False)