我有一个包含两列x
和y
的二维数据集。输入新数据时,我想动态地获得线性回归系数和截距。使用scikit-learn,我可以像这样计算所有当前可用数据:
from sklearn.linear_model import LinearRegression
regr = LinearRegression()
x = np.arange(100)
y = np.arange(100)+10*np.random.random_sample((100,))
regr.fit(x,y)
print(regr.coef_)
print(regr.intercept_)
但是,我有一个很大的数据集(总共超过1万行),我想计算系数并在有新行进入时尽可能快地进行拦截。目前计算1万行大约需要600微秒,我想加快这一过程。
Scikit-learn似乎没有线性回归模块的在线更新功能。有更好的方法吗?
答案 0 :(得分:1)
我从本文中找到了解决方案:updating simple linear regression。实现如下:
def lr(x_avg,y_avg,Sxy,Sx,n,new_x,new_y):
"""
x_avg: average of previous x, if no previous sample, set to 0
y_avg: average of previous y, if no previous sample, set to 0
Sxy: covariance of previous x and y, if no previous sample, set to 0
Sx: variance of previous x, if no previous sample, set to 0
n: number of previous samples
new_x: new incoming 1-D numpy array x
new_y: new incoming 1-D numpy array x
"""
new_n = n + len(new_x)
new_x_avg = (x_avg*n + np.sum(new_x))/new_n
new_y_avg = (y_avg*n + np.sum(new_y))/new_n
if n > 0:
x_star = (x_avg*np.sqrt(n) + new_x_avg*np.sqrt(new_n))/(np.sqrt(n)+np.sqrt(new_n))
y_star = (y_avg*np.sqrt(n) + new_y_avg*np.sqrt(new_n))/(np.sqrt(n)+np.sqrt(new_n))
elif n == 0:
x_star = new_x_avg
y_star = new_y_avg
else:
raise ValueError
new_Sx = Sx + np.sum((new_x-x_star)**2)
new_Sxy = Sxy + np.sum((new_x-x_star).reshape(-1) * (new_y-y_star).reshape(-1))
beta = new_Sxy/new_Sx
alpha = new_y_avg - beta * new_x_avg
return new_Sxy, new_Sx, new_n, alpha, beta, new_x_avg, new_y_avg
性能比较:
Scikit学习版本,可总共计算1万个样本。
from sklearn.linear_model import LinearRegression
x = np.arange(10000).reshape(-1,1)
y = np.arange(10000)+100*np.random.random_sample((10000,))
regr = LinearRegression()
%timeit regr.fit(x,y)
# 419 µs ± 14.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
我的版本假定已经计算了9k个样本:
Sxy, Sx, n, alpha, beta, new_x_avg, new_y_avg = lr(0, 0, 0, 0, 0, x.reshape(-1,1)[:9000], y[:9000])
new_x, new_y = x.reshape(-1,1)[9000:], y[9000:]
%timeit lr(new_x_avg, new_y_avg, Sxy,Sx,n,new_x, new_y)
# 38.7 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
快10倍,这是预期的。