有效扩张的linregress

时间:2014-08-14 06:49:11

标签: python numpy pandas scipy

我有一个包含两列(X和Y坐标)的数据框。我需要从df开始的扩展线性回归。例如,在第二点,我需要前两点的回归;在第3点我需要前3分,依此类推。根据文档,对于具有一个输入和一个输出expanding_apply的函数,可以使用,但linregress提供5个输出。

目前我正在对所有行执行for循环,这样可以正常工作但不出所料非常慢,甚至几乎无法使用。

我尝试了几件事,但遭到拒绝。尝试将输入作为元组发送:

pd.expanding_apply((df.x, df.y), linregress)
> AttributeError: 'tuple' object has no attribute 'dtype'

尝试将输入发送为df:

pd.expanding_apply(df[['x','y']], linregress)
> IndexError: tuple index out of range

每个回归几乎与之前的回归相同(只有一个添加的数据点),因此可能会有大量的加速空间。有没有办法在熊猫或numpy / scipy域中实现这个更高效的方法?

编辑:linregress可选择接受2维数组(而不是2个独立的一维数组),因此linregress(df[['x','y']])本身可以正常工作。但是expanding_apply可能期望arg系列,而不是df。

1 个答案:

答案 0 :(得分:3)

要计算y = a*x + b的线性回归参数,您必须求解一个超定方程组X*a = y,其中:

X = [[1, x0], [1, x1], ..., [1, x(n-1)]]
a = [b, a]
y = [ y0, y1, ..., y(n-1)]

如果您仅使用ab的值,则可以将系统的两侧预乘X.T,并解决生成的2x2系统。仔细看看,得到的数组可以写成:

np.dot(X.T, X) = [[n, np.sum(x)],
                  [np.sum(x), np.sum(x*x)]]
np.dot(X.T, y) = [np.sum(y), np.sum(x*y)]

将所有这些放在一起,给定两个长度相等的1D阵列xy,您可以使用numpy> = 1.8执行以下操作:

n = 10
x, y = np.random.rand(2, n)

lhs = np.empty((n-1, 2, 2))
rhs = np.empty((n-1, 2))

lhs[:, 0, 0] = np.arange(2, n+1)
lhs[:, 0, 1] = np.cumsum(x)[1:]
lhs[:, 1, 0] = lhs[:, 0, 1]
lhs[:, 1, 1] = np.cumsum(x*x)[1:]

rhs[:, 0] = np.cumsum(y)[1:]
rhs[:, 1] = np.cumsum(x*y)[1:]

a = np.linalg.solve(lhs, rhs)

您可以检查a是否包含与polyfit的结果相对应的正确参数:

In [49]: a
Out[49]:
array([[ 0.64778976, -0.39918768],
       [ 0.76225593, -0.41054035],
       [ 0.72598372, -0.35430181],
       [ 0.70608159, -0.33873589],
       [ 0.6899674 , -0.34941498],
       [ 0.68270772, -0.34834723],
       [ 0.71031366, -0.59487271],
       [ 0.7422803 , -0.74757567],
       [ 0.65982282, -0.48593478]])

In [50]: for j in range(2, n+1):
   ....:     print np.polynomial.polynomial.polyfit(x[:j], y[:j], 1)
   ....:
[ 0.64778976 -0.39918768]
[ 0.76225593 -0.41054035]
[ 0.72598372 -0.35430181]
[ 0.70608159 -0.33873589]
[ 0.6899674  -0.34941498]
[ 0.68270772 -0.34834723]
[ 0.71031366 -0.59487271]
[ 0.7422803  -0.74757567]
[ 0.65982282 -0.48593478]