How do I find a suitable linear fit in Python?

Date: 2015-11-12 14:04:05

Tags: python numpy matplotlib scipy linear-regression

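I'm trying to find the best linear fit for a large dataset that behaves linearly in most of its samples. Plotted in raw form, the data (link) looks like this: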

[Figure: the raw data]

I need a linear fit that covers most of the points, like the thick orange line in the figure below:

[Figure: the same data with the desired fit drawn as a thick orange line]

I tried averaging the points, but how can I extract the linear region with Python?

Reproducible code:

import itertools

import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize


# Column 0 holds x; every remaining column is one y series
data = np.loadtxt('linear.dat', skiprows=1, delimiter='\t')
print(data)
x = data[:, 0]
y = data[:, 1:]
m, n = y.shape


def linear_fit(x, a, b):
    return a * x + b

# Separate straight-line fit for each y column (curve_fit expects finite values)
y_fit = np.empty(shape=(m, n))
for i in range(n):
    fit_a, fit_b = optimize.curve_fit(linear_fit, x, y[:, i])[0]
    y_fit[:, i] = fit_a * x + fit_b

y[~np.isfinite(y)] = 0
y_mean = np.mean(y, axis=1)  # point-wise mean across all y columns

fig = plt.figure(figsize=(5, 5))
plot_y_vs_x = plt.subplot(111)
markers = itertools.cycle(('o', '^', 's', 'v', 'h', '>', 'p', '<'))
for i in range(n):
    # scatter column i with its own marker style
    plot_y_vs_x.plot(x, y[:, i], linestyle='', marker=next(markers), alpha=1, zorder=2)
    # plot_y_vs_x.plot(x, y_fit[:, i], linestyle=':', color='gray', linewidth=0.5, zorder=1)
plot_y_vs_x.plot(x, y_mean, linestyle='-', linewidth=3.0, color='red', zorder=3)
plot_y_vs_x.set_ylim([-10, 10])
plot_y_vs_x.set_ylabel('Y', labelpad=6)
plot_y_vs_x.set_xlabel('X', labelpad=6)
fig.savefig('plot.pdf')
plt.close()

2 answers:

Answer 0 (score: 1)

What you're looking for is linear regression, i.e. an ordinary least-squares fit. For that:

import numpy as np

x = np.array([0, 1, 2, 3])
y = np.array([-1, 0.2, 0.9, 2.1])
A = np.vstack([x, np.ones(len(x))]).T  # design matrix with columns [x, 1]
m, c = np.linalg.lstsq(A, y, rcond=None)[0]  # least-squares solution of A @ [m, c] ≈ y

This gives you the values m and c of the fit y = mx + c. Just replace x and y with your own data as numpy arrays.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.lstsq.html

Answer 1 (score: 1)

I'd suggest using polyfit together with the poly1d class:

polyfit gives the least-squares polynomial fit of whatever degree you choose (1 for a straight line).
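To make the return value concrete: for degree 1, np.polyfit returns the coefficients highest power first, i.e. [slope, intercept], and np.poly1d wraps them in a callable polynomial. A tiny check on made-up points that lie exactly on y = 2x + 1:

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0                # exact line: slope 2, intercept 1

coeffs = np.polyfit(x, y, 1)     # degree 1 -> [slope, intercept]
line = np.poly1d(coeffs)         # callable polynomial built from the coefficients

print(coeffs)                    # approximately [2. 1.]
print(line(4.0))                 # approximately 9.0
```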

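import matplotlib.pyplot as plt
import numpy as np

a = np.genfromtxt('linear.dat', skip_header=1)
x = a[:, 0]
y = a[:, 1:]
k = np.linspace(700, 1000, 50)  # x grid on which to draw the fitted curves
plt.clf()
for z in y.T:
    plt.scatter(x, z)
for z in y.T:
    fit = np.polyfit(x, z, 2)  # degree 2 here; increase the degree to get a closer fit
    fit_fn = np.poly1d(fit)
    plt.plot(k, fit_fn(k))
plt.xlim(700, 1000)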

[Figure: every y column with its fitted polyfit curve]

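The data doesn't look very linear to me. But you could fit only a subset of the points, here the last 10 rows of the data (x[-10:]):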

k = np.linspace(700, 900, 50)
plt.clf()
plt.scatter(x, y[:, 5])  # e.g. column 5
fit = np.polyfit(x[-10:], y[-10:, 5], 1)  # degree 1: straight line through the last 10 rows only
fit_fn = np.poly1d(fit)
plt.plot(k, fit_fn(k))
plt.xlim(700, 1000)

[Figure: column 5 with the straight-line fit]
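If you'd rather select the "linear region" automatically than pick rows by hand, one option (a different technique from both answers) is an iterative trimmed least-squares fit: fit a line, keep only the fraction of points closest to it, refit, and repeat until the kept set stops changing. This is a minimal sketch; the helper name trimmed_line_fit, the default keep_frac=0.8 and the synthetic data are assumptions for illustration, not values taken from linear.dat.

```python
import numpy as np


def trimmed_line_fit(x, y, keep_frac=0.8, max_iter=50):
    """Hypothetical helper: straight-line fit that keeps only the keep_frac
    share of points lying closest to the current line, re-selecting them
    until the selection stops changing. Returns slope, intercept, mask."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    idx = np.flatnonzero(np.isfinite(x) & np.isfinite(y))  # usable points
    n_keep = max(2, int(round(keep_frac * idx.size)))
    keep = idx                                # first pass: ordinary fit on all usable points
    for _ in range(max_iter):
        m, c = np.polyfit(x[keep], y[keep], 1)
        resid = np.abs(y[idx] - (m * x[idx] + c))
        new_keep = np.sort(idx[np.argsort(resid)[:n_keep]])  # n_keep closest points
        if np.array_equal(new_keep, keep):
            break
        keep = new_keep
    m, c = np.polyfit(x[keep], y[keep], 1)    # final fit on the selected points
    mask = np.zeros(x.size, dtype=bool)
    mask[keep] = True
    return m, c, mask


# Synthetic example: a noisy line y = 0.02 x - 15 whose right end curves upward
rng = np.random.default_rng(1)
x = np.linspace(700, 1000, 61)
y = 0.02 * x - 15 + rng.normal(0, 0.05, x.size)
tail = x > 930
y[tail] += 0.002 * (x[tail] - 930) ** 2
m, c, mask = trimmed_line_fit(x, y)
print(m, c, mask.sum())
```

For the question's data you could call it per column, e.g. trimmed_line_fit(x, y[:, 5]). keep_frac encodes how much of the data you believe is linear, so it still has to be chosen from the plot or from domain knowledge. Established robust estimators you could use instead include scikit-learn's RANSACRegressor and HuberRegressor, or scipy.optimize.least_squares with loss='soft_l1'.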