在pymc3 Logistic回归中使用NUTS调试收敛和慢采样问题

时间:2018-08-01 21:50:35

标签: python optimization statistics pymc3 mcmc

我正在从GitHub issue交叉发布同一主题。

我正在尝试使用pymc中的默认NUTS算法确定慢采样率和收敛问题的原因,该方法在150万观察值和500个协变量(1500000 x 500设计矩阵)的数据集上拟合了标准逻辑模型。

下面是一些代码来创建这种大小的玩具数据集:

import numpy as np
covariates = np.random.randn(1500000, 499)
covariates = np.hstack((np.ones((1500000, 1)), covariates))

true_coefficients = 5 * np.random.rand(500)
true_logits = np.dot(covariates, true_coefficients) 
true_probs = 1.0 / (1.0 + np.exp(-true_logits))
observed_labels = (np.random.rand(1500000) < true_probs).astype(np.int32)

例如,可以使用以下代码从此处轻松创建statsmodelssklearn.linear_model.LogisticRegression模型拟合:

from sklearn.linear_model import LogisticRegression

logistic = LogisticRegression(max_iter=1000, fit_intercept=False, verbose=1)
logistic.fit(covariates, observed_labels)
# above takes 3-5 minutes on my machine, appears to converge well.

import matplotlib.pyplot as plt
plt.scatter(true_coefficients, logistic.coef_[0, :])
plt.show()
# above scatter plot shows good accuracy in the point estimate.

即使对于完整的数据集,sklearn模型也可以合理地收敛,并且与真实系数相比,可以得出精度较高的求解参数向量。

但是使用下面的pymc代码,无论使用NUTS还是简单的Metropolis采样器,我经常会看到不可接受的缓慢采样时间。我尝试更改许多默认设置,包括对数据进行二次采样以及其他各种技巧,但是都没有运气。

import pymc3 as pm
import theano.tensor as tt

with pm.Model() as logistic_model:
    beta = pm.Normal('beta', 0.0, sd=3.0, shape=500)
    p = 1.0 / (1.0 + tt.exp(-tt.dot(covariates, beta)))
    likelihood = pm.Bernoulli('likelihood', p, observed=observed_labels)

with logistic_model:
    tr = pm.sample(1000, njobs=2, nchains=2)

# In the sample step above I have also tried all types of variations on using
# step = pm.NUTS(), step = pm.Metropolis(), start = pm.find_MAP(),
# start = {'beta': np.zeros(500)}, and adjustment of all sorts of options with
# kwargs to these steppers.

# I have also tried wrapping `covariates` and `observed_labels` with pm.Minibatch
# with batch sizes ranging from 1000 to 200000 observations per sample. With NUTS,
# the minibatches made it run drastically slower, over 40 seconds per sample.

对于NUTS算法,我发现如果我不手动将起始值设置为零向量或不使用pm.find_MAP(在pymc文档中不建议使用),那么我将得到一个“错误”初始化过程中出现“能量”错误:

RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
  File "/Users/espears/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 73, in run
    self._start_loop()
  File "/Users/espears/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 113, in _start_loop
    point, stats = self._compute_point()
  File "/Users/espears/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 139, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/Users/espears/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 247, in step
    apoint, stats = self.astep(array)
  File "/Users/espears/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 117, in astep
    'might be misspecified.' % start.energy)
ValueError: Bad initial energy: inf. The model might be misspecified.
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Bad initial energy: inf. The model might be misspecified.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
~/programming/pymc-testing/logistic_example.py in <module>()
     45 with logistic_model:
     46     #step, start = pm.Metropolis(), {'beta': np.zeros(num_coefficients)}
---> 47     tr = pm.sample(1000, njobs=2, nchains=2)
     48 print(f"{time.time() - st}: finished_sampling...")
     49

~/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    449             _print_step_hierarchy(step)
    450             try:
--> 451                 trace = _mp_sample(**sample_args)
    452             except pickle.PickleError:
    453                 _log.warning("Could not pickle model, sampling singlethreaded.")

~/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
    999         try:
   1000             with sampler:
-> 1001                 for draw in sampler:
   1002                     trace = traces[draw.chain - chain]
   1003                     if trace.supports_sampler_stats and draw.stats is not None:

~/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/parallel_sampling.py in __iter__(self)
    303
    304         while self._active:
--> 305             draw = ProcessAdapter.recv_draw(self._active)
    306             proc, is_last, draw, tuning, stats, warns = draw
    307             if self._progress is not None:

~/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
    221         if msg[0] == 'error':
    222             old = msg[1]
--> 223             six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
    224         elif msg[0] == 'writing_done':
    225             proc._readable = True

~/anaconda3/envs/py36-pymc/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

RuntimeError: Chain 0 failed.

对于NUTS或简单的Metropolis算法,这都是非常出乎意料的,因为对数似然函数的梯度对于该模型而言非常简单,我使用的是合理且有用的先验值(将标准偏差设置为3),并且1500000 x 500的数据大小非常小,可以舒适地容纳在内存中,并且不会引起标准统计包(例如statsmodels或scikit-learn(使用二阶导数等),与NUTS相同)的问题)。

我正在寻找有关需要哪些采样器设置的指导,以便获得完整数据的快速采样率,以便可以更长的采样时间进行收敛。

0 个答案:

没有答案