pymc3多变量traceplot颜色编码

时间:2017-04-03 19:57:19

标签: python formatting pymc3

我是pymc3的新手,我在生成一个易于阅读的traceplot时遇到了麻烦。 我将4个多元高斯的混合拟合到数据集中的某些(x,y)点。模型运行良好。我的问题是关于操作pm.traceplot()命令以使输出更加用户友好。 这是我的代码:

import matplotlib.pyplot as plt
import numpy as np
model = pm.Model()
N_CLUSTERS = 4
with model:
    #cluster prior
    w = pm.Dirichlet('w', np.ones(N_CLUSTERS))
    #latent cluster of each observation
    category = pm.Categorical('category', p=w, shape=len(points))

    #make sure each cluster has some values:
    w_min_potential = pm.Potential('w_min_potential', tt.switch(tt.min(w) < 0.1, -np.inf, 0))
    #multivariate normal means
    mu = pm.MvNormal('mu', [0,0], cov=[[1,0],[0,1]], shape = (N_CLUSTERS,2) )
    #break symmetry
    pm.Potential('order_mu_potential', tt.switch(
                                                tt.all(
                                                   [mu[i, 0] < mu[i+1, 0] for i in range(N_CLUSTERS - 1)]), -np.inf, 0))
    #multivariate centers
    data = pm.MvNormal('data', mu =mu[category], cov=[[1,0],[0,1]],  observed=points)

 with model:
     trace = pm.sample(1000)

pm.traceplot(trace, ['w', 'mu'])的调用会生成此图片: Default traceplot() output

正如您所看到的,它是模糊的,其中平均峰值对应于x或y值,哪些是成对的。我已按如下方式管理了一个解决方法:

from cycler import cycler
#plot the x-means and y-means of our data!
fig, (ax0, ax1) = plt.subplots(nrows=2)
plt.xlabel('$\mu$')
plt.ylabel('frequency')
for i in range(4):
    ax0.hist(trace['mu'][:,i,0], bins=100, label='x{}'.format(i), alpha=0.6);
    ax1.hist(trace['mu'][:,i,1],bins=100, label='y{}'.format(i), alpha=0.6);
ax0.set_prop_cycle(cycler('color', ['c', 'm', 'y', 'k']))
ax1.set_prop_cycle(cycler('color', ['c', 'm', 'y', 'k']))
ax0.legend()
ax1.legend()

这会产生以下更清晰的情节: A more legible histogram of means

我在这里查看了pymc3文档和最近的问题,但无济于事。我的问题是:是否有可能通过pymc3中的内置方法使用matplotlib完成我在这里所做的事情,如果有,怎么做?

2 个答案:

答案 0 :(得分:0)

至少在最新版本中,您可以按以下方式使用compact=True

pm.traceplot(trace, var_names = ['parameters'], compact=True)

获得一张包含所有参数的图形 https://arviz-devs.github.io/arviz/_modules/arviz/plots/traceplot.html

中的文档

但是,我无法使行之间的颜色不同

答案 1 :(得分:0)

多维变量和不同链之间的更好区分最近被添加到ArviZ中(PyMC3库依赖于绘图)。

在ArviZ最新版本中,您应该可以做到:

az.plot_trace(trace, compact=True, legend=True)

以获得每个变量的不同尺寸(按颜色区分)和不同的链(按线型区分)。默认设置是使用matplotlib的默认颜色周期和4种不同的线型,实线,虚线,点线和点划线。通过使用compact_prop自定义尺寸表示和chain_prop自定义链表示,可以将这两个属性设置为自定义美学和自定义值。此外,如果使用compact,则最好使用combined=True来减少第一列中的混乱情况。例如:

az.plot_trace(trace, compact=True, combined=True, legend=True, chain_prop=("ls", "-"))

将使用来自所有链的数据在第一列中绘制KDE,并使用实线样式绘制所有链(由于合并的arg,仅与第二列相关)。将显示两个图例,一个用于链信息,另一个用于紧凑信息。