stan矢量参数的一个不错的pystan轨迹图

时间:2018-10-15 10:02:03

标签: python stan pystan

我正在Stan中进行多元回归。

我想要回归器/设计矩阵的beta矢量参数的迹线图。

当我执行以下操作时:

fit = model.sampling(data=data, iter=2000, chains=4)
fig = fit.plot('beta')

我得到一个非常恐怖的图像:

horrid trace plot for vector parameter

我在追求一些用户友好的东西。我设法破解了以下内容,使其更接近我的追求。

Nicer tracer plot for vector parameter

我的黑客如下插入pystan的背面。

r = fit.extract() # r for results
from pystan.external.pymc import plots
param = 'beta'
beta = r[param] 
name = df.columns.values.tolist()
(rows, cols) = beta.shape
assert(len(df.columns) == cols)
values = {param+'['+str(k+1)+'] '+name[k]: 
    beta[:,k] for k in range(cols)}
fig = plots.traceplot(values, values.keys())
for a in fig.axes:
    # shorten the y-labels
    l = a.get_ylabel()
    if l == 'frequency': 
        a.set_ylabel('freq')
    if l=='sample value': 
        a.set_ylabel('val')
fig.set_size_inches(8, 12)
fig.tight_layout(pad=1)
fig.savefig(g_dir+param+'-trace.png', dpi=125)
plt.close()

我的问题-当然我错过了一些事情-但是有没有更简单的方法可以从pystan获得我想要的矢量参数输出?

2 个答案:

答案 0 :(得分:1)

发现ArviZ模块可以很好地做到这一点。

可以在这里找到ArviZ:https://arviz-devs.github.io/arviz/

答案 1 :(得分:0)

我对此也很挣扎,只是找到了一种方法来提取traceplot的参数(我已经知道beta)。

合适时,可以将其保存到数据框中:

fit_df = fit.to_dataframe()

现在您有了一个新变量,即数据框。是的,我花了一段时间才发现pystan有一种直接的方法来将拟合保存到数据框中。

有了它,您就可以检查数据框。您可以通过打印按键来查看其标题:

fit_df.keys()

输出是这样的:

Index([u'chain', u'chain_idx', u'warmup', u'accept_stat__', u'energy__',
       u'n_leapfrog__', u'stepsize__', u'treedepth__', u'divergent__',
       u'beta[1,1]', ...
       u'eta05[892]', u'eta05[893]', u'eta05[894]', u'eta05[895]',
       u'eta05[896]', u'eta05[897]', u'eta05[898]', u'eta05[899]',
       u'eta05[900]', u'lp__'],
      dtype='object', length=9037)

现在,您拥有了所需的一切! Beta在列中以及链ID。这就是绘制beta和traceplot所需要的。因此,您可以根据需要任意操作它,并根据需要自定义图形。我将向您展示如何做到这一点的示例:

chain_idx = fit_df['chain_idx']
beta11 = fit_df['beta[1,1]']
beta12 = fit_df['beta[1,2]']

plt.subplots(figsize=(15,3))
plt.subplot(1,4,1)
sns.kdeplot(beta11)
plt.subplot(1,4,2)
plt.plot(chain_idx, beta11)

plt.subplot(1,4,3)
sns.kdeplot(beta12)
plt.subplot(1,4,4)
plt.plot(chain_idx, beta12)

plt.tight_layout()
plt.show()

The image from the above plot!

我希望它会有所帮助(如果您仍然需要);)