Pandas / Statsmodel OLS预测未来价值

时间:2014-08-26 19:58:53

标签: python pandas linear-regression statsmodels

我一直试图在我创建的模型中预测未来的值。我在熊猫和statsmodels中尝试了两种OLS。以下是我在statsmodels中的内容:

import statsmodels.api as sm
endog = pd.DataFrame(dframe['monthly_data_smoothed8'])
smresults = sm.OLS(dframe['monthly_data_smoothed8'], dframe['date_delta']).fit()
sm_pred = smresults.predict(endog)
sm_pred

返回的数组长度等于原始数据帧中的记录数,但值不相同。当我使用pandas执行以下操作时,我没有返回任何值。

from pandas.stats.api import ols
res1 = ols(y=dframe['monthly_data_smoothed8'], x=dframe['date_delta'])
res1.predict

(请注意,Pandas中的OLS没有.fit功能)有人可以了解我如何从我的OLS模型中获得未来的预测 - 无论是pandas还是statsmodel - 我意识到我一定不能正确使用.predict我已经阅读了人们遇到的其他多个问题,但它们似乎并不适用于我的案例。

编辑我认为定义的'endog'是不正确的 - 我应该传递我想要预测的值;因此,我创建了一个超过最后记录值的12个期间的日期范围。但是我仍然错过了一些错误:

matrices are not aligned

编辑这里是一段数据,数字的最后一列(红色)是日期增量,是与第一个日期相差几个月的数据:

month   monthly_data    monthly_data_smoothed5  monthly_data_smoothed8  monthly_data_smoothed12 monthly_data_smoothed3  date_delta
0   2011-01-31  3.711838e+11    3.711838e+11    3.711838e+11    3.711838e+11    3.711838e+11    0.000000
1   2011-02-28  3.776706e+11    3.750759e+11    3.748327e+11    3.746975e+11    3.755084e+11    0.919937
2   2011-03-31  4.547079e+11    4.127964e+11    4.083554e+11    4.059256e+11    4.207653e+11    1.938438
3   2011-04-30  4.688370e+11    4.360748e+11    4.295531e+11    4.257843e+11    4.464035e+11    2.924085

1 个答案:

答案 0 :(得分:6)

我认为你的问题是statsmodels默认情况下不会添加拦截,因此你的模型并没有达到很大的效果。在你的代码中解决它将是这样的:

dframe = pd.read_clipboard() # your sample data
dframe['intercept'] = 1
X = dframe[['intercept', 'date_delta']]
y = dframe['monthly_data_smoothed8']

smresults = sm.OLS(y, X).fit()

dframe['pred'] = smresults.predict()

另外,对于它的价值,我认为在处理DataFrames时,statsmodel公式api更好用,并且默认添加一个拦截(添加- 1删除)。见下文,它应该给出相同的答案。

import statsmodels.formula.api as smf

smresults = smf.ols('monthly_data_smoothed8 ~ date_delta', dframe).fit()

dframe['pred'] = smresults.predict()

编辑:

要预测未来的值,只需将新数据传递给.predict()例如,使用第一个模型:

In [165]: smresults.predict(pd.DataFrame({'intercept': 1, 
                                          'date_delta': [0.5, 0.75, 1.0]}))
Out[165]: array([  2.03927604e+11,   2.95182280e+11,   3.86436955e+11])

在截距上 - 在1数字中没有编码只是基于OLS的数学(截距完全类似于总是等于1的回归量),所以您可以从摘要中提取值。查看statsmodels docs,添加拦截的另一种方法是:

X = sm.add_constant(X)