如何将pymc3变量分配给当前活动的模型?

时间:2018-03-30 10:57:45

标签: python pymc3

PyMC3中你可以这样做

basic_model = pm.Model()

with basic_model:

    # Priors for unknown model parameters
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=10, shape=2)
    sigma = pm.HalfNormal('sigma', sd=1)

    # Expected value of outcome
    mu = alpha + beta[0]*X1 + beta[1]*X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

并且所有变量(pm.Normal,...)将被分配"到basic_model实例。

From the docs

  

第一行,

basic_model = Model()
     

创建一个新的Model对象,它是模型random的容器   变量

     

在模型实例化之后,随后的规范   模型组件在with语句中执行:

with basic_model:
     

这创建了一个上下文管理器,我们的basic_model作为上下文,   包括所有语句,直到缩进块结束。这意味着   所有PyMC3对象都在with下面的缩进代码块中引入   语句被添加到幕后的模型中。没有这个   上下文管理器成语,我们将被迫手动关联每个   在我们创建它们之后立即使用basic_model。如果你   尝试创建一个没有with model:语句的新随机变量,   它会引发错误,因为没有明显的模型   要添加的变量。

我认为这对图书馆来说非常优雅。这是如何实际实现的?

我能想到的唯一方法就是本着这个精神:

class Model:
    active_model = None
    def __enter__(self):
        Model.active_model = self
    def __exit__(self, *args, **kwargs):
        Model.active_model = None

class Normal:
    def __init__(self):
        if Model.active_model is None:
            raise ValueError("cant instantiate variable outside of Model")
        else:
            self.model = Model.active_model

它适用于我简单的REPL测试,但我不确定这是否有一些陷阱,实际上就是这么简单。

1 个答案:

答案 0 :(得分:2)

你非常接近,它甚至与你的实施非常相似。请注意,threading.local用于存储对象,它作为列表进行维护,以允许嵌套多个模型,并允许多处理。在实际实现中有一点额外的内容,允许在输入我删除的模型上下文时设置theano配置:

class Context(object):
    contexts = threading.local()

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        if not hasattr(cls.contexts, 'stack'):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        """Return the deepest context on the stack."""
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")

Model类子类Context,因此在编写算法时,我们可以从上下文管理器中调用Model.get_context()并访问该对象。这相当于您的Model.active_model