声明pymc3的theano变量

时间:2017-09-29 00:27:13

标签: numpy theano pymc pymc3

我在使用pymc3复制pymc2代码时遇到问题。

我认为这是因为pymc3正在使用与我正在使用的numpy操作不兼容的theano类型变量。所以我使用的是@ theano.decorator:

我有这个功能:

with pymc3.Model() as model:

    z_stars         = pymc3.Uniform('z_star',    self.z_min_ssp_limit, self.z_max_ssp_limit)
    Av_stars        = pymc3.Uniform('Av_star',   0.0, 5.00)
    sigma_stars     = pymc3.Uniform('sigma_star',0.0, 5.0)

    #Fit observational wavelength
    ssp_fit_output = self.ssp_fit_theano(z_stars, Av_stars, sigma_stars, 
                                         self.obj_data['obs_wave_resam'], 
                                         self.obj_data['obs_flux_norm_masked'], 
                                         self.obj_data['basesWave_resam'], 
                                         self.obj_data['bases_flux_norm'], 
                                         self.obj_data['int_mask'], 
                                         self.obj_data['normFlux_obs'])

    #Define likelihood
    like = pymc.Normal('ChiSq', mu=ssp_fit_output, 
                       sd=self.obj_data['obs_fluxEr_norm'], 
                       observed=self.obj_data['obs_fluxEr_norm'])

    #Run the sampler
    trace = pymc3.sample(iterations, step=step, start=start_conditions, trace=db)

其中:

@theano.compile.ops.as_op(itypes=[t.dscalar,t.dscalar,t.dscalar,t.dvector,
                                  t.dvector,t.dvector,t.dvector,t.dvector,t.dscalar],
                          otypes=[t.dvector])
def ssp_fit_theano(self, input_z, input_sigma, input_Av, obs_wave, obs_flux_masked, 
                   rest_wave, bases_flux, int_mask, obsFlux_mean):
   ...
   ...

前三个变量是标量(来自pymc3均匀分布)。该 剩下的变量是numpy数组,最后一个是float。但是,我 得到这个''numpy.ndarray'对象没有属性'type'“错误:

  File "/home/user/anaconda/lib/python2.7/site-packages/theano/gof/op.py", line 615, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/user/anaconda/lib/python2.7/site-packages/theano/gof/op.py", line 963, in make_node
    if not all(inp.type == it for inp, it in zip(inputs, self.itypes)):
  File "/home/user/anaconda/lib/python2.7/site-packages/theano/gof/op.py", line 963, in <genexpr>
    if not all(inp.type == it for inp, it in zip(inputs, self.itypes)):
AttributeError: 'numpy.ndarray' object has no attribute 'type'

请欢迎任何正确方向的建议。

1 个答案:

答案 0 :(得分:2)

当我从pymc2转到pymc3时,我有一堆浪费时间的时间。我认为,问题在于文档非常糟糕。我怀疑他们忽略了文档,因为代码仍在不断发展。 3条评论/建议: