Why does linear regression perform so poorly when trained with PyMC3?

Date: 2018-03-31 21:02:00

Tags: linear-regression pymc3

I am new to PyMC3. Perhaps this is a naive question, but I have searched a lot and have not found any explanation for this issue. Basically, I want to do linear regression in PyMC3, but training is very slow and the model's performance on the training set is very poor. Below is my code:


Here I use PyMC3 to train the model. First, the training data:

import numpy as np
import seaborn as sns

X_tr = np.array([ 13.99802212,  13.8512075 ,  13.9531636 ,  13.97432944,
    13.89211468,  13.91357953,  13.95987483,  13.86476587,
    13.9501789 ,  13.92698143,  13.9653932 ,  14.06663115,
    13.91697969,  13.99629862,  14.01392784,  13.96495713,
    13.98697998,  13.97516973,  14.01048397,  14.05918188,
    14.08342002,  13.89350606,  13.81768849,  13.94942447,
    13.90465027,  13.93969029,  14.18771189,  14.08631113,
    14.03718829,  14.01836206,  14.06758363,  14.05243539,
    13.96287123,  13.93011351,  14.01616973,  14.01923812,
    13.97424024,  13.9587175 ,  13.85669845,  13.97778302,
    14.04192138,  13.93775494,  13.86693585,  13.79985956,
    13.82679677,  14.06474544,  13.90821822,  13.71648423,
    13.78899668,  13.76857337,  13.87201756,  13.86152949,
    13.80447525,  13.99609891,  14.0210165 ,  13.986906  ,
    13.97479211,  14.04562055,  14.03293095,  14.15178043,
    14.32413197,  14.2330354 ,  13.99247751,  13.92962912,
    13.95394525,  13.87888254,  13.82743111,  14.10724699,
    14.23638905,  14.15731881,  14.13239278,  14.13386722,
    13.91442452,  14.01056255,  14.19378649,  14.22233852,
    14.30405399,  14.25880108,  14.23985258,  14.21184303,
    14.4443183 ,  14.55710331,  14.42102092,  14.29047616,
    14.43712609,  14.58666212])
y_tr = np.array([ 13.704,  13.763,  13.654,  13.677,  13.66 ,  13.735,  13.845,
    13.747,  13.747,  13.606,  13.819,  13.867,  13.817,  13.68 ,
    13.823,  13.779,  13.814,  13.936,  13.956,  13.912,  13.982,
    13.979,  13.919,  13.944,  14.094,  13.983,  13.887,  13.902,
    13.899,  13.881,  13.784,  13.909,  13.99 ,  14.06 ,  13.834,
    13.778,  13.703,  13.965,  14.02 ,  13.992,  13.927,  14.009,
    13.988,  14.022,  13.754,  13.837,  13.91 ,  13.907,  13.867,
    14.014,  13.952,  13.796,  13.92 ,  14.051,  13.773,  13.837,
    13.745,  14.034,  13.923,  14.041,  14.077,  14.125,  13.989,
    14.174,  13.967,  13.952,  14.024,  14.171,  14.175,  14.091,
    14.267,  14.22 ,  14.071,  14.112,  14.174,  14.289,  14.146,
    14.356,  14.5  ,  14.265,  14.259,  14.406,  14.463,  14.473,
    14.413,  14.507])
sns.regplot(x=X_tr, y=y_tr.flatten());

Then I fit the model and check its performance:

import pymc3 as pm
from theano import shared

shA_X = shared(X_tr)
with pm.Model() as linear_model:
    alpha = pm.Normal("alpha", mu=14, sd=100)
    betas = pm.Normal("betas", mu=0, sd=100, shape=1)
    sigma = pm.HalfCauchy('sigma', beta=10, testval=1.)
    mu = alpha + betas * shA_X
    forecast = pm.Normal("forecast", mu=mu, sd=sigma, observed=y_tr)
    # NUTS is assigned automatically by pm.sample, so no explicit step is needed
    trace = pm.sample(3000, tune=1000)
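Once sampling finishes there is no single "fitted line"; the result is a set of posterior draws per parameter. A minimal numpy-only sketch of how such draws can be summarized (using synthetic arrays as stand-ins for the real `trace` object, with means chosen to match the values discussed in the answer below):

```python
import numpy as np

# Synthetic stand-ins for trace['alpha'] and trace['betas'] -- the real
# draws come from pm.sample() above; these arrays are illustrative only.
rng = np.random.RandomState(42)
fake_trace = {
    "alpha": rng.normal(2.34, 0.5, size=3000),
    "betas": rng.normal(0.83, 0.04, size=3000),
}

# Posterior mean plus a 95% credible interval for each parameter.
for name, draws in fake_trace.items():
    lo, hi = np.percentile(draws, [2.5, 97.5])
    print(f"{name}: mean={draws.mean():.3f}, 95% interval=({lo:.3f}, {hi:.3f})")
```

With a real trace, `trace['alpha']` and `trace['betas']` are already numpy arrays, so exactly the same `mean` / `percentile` calls apply.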


Why are the predictions on the training set so poor? Any suggestions and ideas are appreciated.

1 Answer:

Answer 0 (score: 3):

The model is doing just fine; I think the confusion is about how to work with PyMC3 objects (thanks for providing an example that was easy to run!). In general, PyMC3 is used to quantify the uncertainty in your model.

For example, trace['betas'].mean() is about 0.83 (this will depend on your random seed), while the least-squares estimate that, e.g., sklearn gives is 0.826. Similarly, trace['alpha'].mean() gives 2.34, while the "true" value is 2.38.
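That least-squares cross-check needs no Bayesian machinery at all. A sketch using `np.polyfit` on synthetic data with a known slope and intercept (stand-ins for the actual `X_tr` / `y_tr` arrays from the question; the "true" values 0.826 and 2.38 are taken from the comparison above):

```python
import numpy as np

# Synthetic data standing in for X_tr / y_tr, generated with a known
# slope (0.826) and intercept (2.38) plus Gaussian noise.
rng = np.random.RandomState(0)
x = rng.normal(14.0, 0.2, size=200)
y = 2.38 + 0.826 * x + rng.normal(0, 0.1, size=200)

# Ordinary least squares; this is the number a posterior mean such as
# trace['betas'].mean() should land near.
slope, intercept = np.polyfit(x, y, deg=1)
print(slope, intercept)
```

Agreement between the posterior means and the least-squares fit is a quick sanity check that the sampler converged to something sensible.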

You can also use the trace to plot many different plausible lines of best fit:

import matplotlib.pyplot as plt

for draw in trace[::100]:
    pred = draw['betas'] * X_tr + draw['alpha']
    plt.plot(X_tr, pred, '--', alpha=0.2, color='grey')


plt.plot(X_tr, y_tr, 'o');


Note that these are draws from the distribution of best-fit lines for your data. You also modeled the noise with sigma, and you can plot that value as well:

for draw in trace[::100]:
    pred = draw['betas'] * X_tr + draw['alpha']

    plt.plot(X_tr, pred, '-', alpha=0.2, color='grey')
    plt.plot(X_tr, pred + draw['sigma'], '-', alpha=0.05, color='red')
    plt.plot(X_tr, pred - draw['sigma'], '-', alpha=0.05, color='red');


plt.plot(X_tr, y_tr, 'o');


Using sample_ppc draws observed values from the posterior distribution, so each row of ppc_w['forecast'] is a plausible way for the data to have been generated "next time". You can use that object like this:

ppc_w = pm.sample_ppc(trace, 1000, linear_model,
                      progressbar=False)
for draw in ppc_w['forecast'][::5]:
    sns.regplot(X_tr, draw, scatter_kws={'alpha': 0.005, 'color': 'r'}, fit_reg=False)
sns.regplot(X_tr, y_tr, color='k', fit_reg=False);
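A common next step is to collapse the posterior predictive draws into an interval per data point instead of overplotting every draw. A numpy-only sketch, using a random matrix as a stand-in for `ppc_w['forecast']` (rows = draws, columns = data points):

```python
import numpy as np

# Fake posterior predictive samples: 1000 draws x 86 data points,
# standing in for ppc_w['forecast'] (86 matches the size of y_tr above).
rng = np.random.RandomState(7)
fake_ppc = rng.normal(14.0, 0.1, size=(1000, 86))

# Per-point 95% predictive interval, taken across draws (axis 0).
lower = np.percentile(fake_ppc, 2.5, axis=0)
upper = np.percentile(fake_ppc, 97.5, axis=0)
print(lower.shape, upper.shape)
```

With the real object, `np.percentile(ppc_w['forecast'], [2.5, 97.5], axis=0)` gives a band you can draw with `plt.fill_between`; roughly 95% of the observed points should fall inside it if the model is well calibrated.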
