如何指定scikit-learn的高斯过程回归的先验?

时间:2019-03-01 20:20:23

标签: python scikit-learn regression

here所述,scikit-learn的高斯过程回归(GPR)允许“无需先验拟合即可进行预测(基于GP先验)”。但是我对先验的状态有所了解(即,它不应该简单地具有零均值,而也许我的输出@pytest.mark.parametrize(('values','expected_code'),(_internal_function,[100,200,300]))与输入y(即{{1})成线性比例关系}。如何将这些信息编码为GPR?

以下是一个有效的示例,但对于我的先前示例,它假设均值为零。我read说:“ GaussianProcessRegressor不允许指定均值函数,始终假定它是零函数,突出了均值函数在计算后验中的作用已减弱。”我相信这是custom kernels(例如异方差)背后的动机,它们在不同的X上具有可变的规模,尽管我仍在尝试更好地了解它们提供的功能。有没有办法绕过零均值先验,以便可以在scikit-learn中指定任意先验?

y = X

1 个答案:

答案 0 :(得分:0)

这里是有关如何在sklearn GPR模型中使用先验均值函数的示例。

import numpy as np
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel

A=np.linspace(5,25,num=100)
# prior mean function
prior_beta=12-0.3*A
# true function
true_beta=20-0.7*A

rng = np.random.seed(44)
# Training data
size=15
ind=np.random.randint(0,100,size=size)
# generate the posterior variance (noisy samples)
var_=np.random.uniform(0.1,10.0,size=size)
A_=A[ind][:, np.newaxis]
beta_=true_beta[ind]-prior_beta[ind]
beta_1=true_beta[ind]

plt.figure()

kernel = ConstantKernel(4) * RBF(length_scale=2, length_scale_bounds=(1e-3, 1e2))
gp = GaussianProcessRegressor(kernel=kernel,
                              alpha=var_,optimizer=None).fit(A_, beta_)
X_ = np.linspace(5, 25, 100)
y_mean, y_cov = gp.predict(X_[:, np.newaxis], return_cov=True)
# Now you add the prior mean function back
y_mean=y_mean+12-0.3*X_
plt.plot(X_, y_mean, 'k', lw=3, zorder=9, label='predicted')
plt.fill_between(X_, y_mean - 3*np.sqrt(np.diag(y_cov)),
                 y_mean + 3*np.sqrt(np.diag(y_cov)),
                 alpha=0.5, color='k', label='+-3sigma')
plt.plot(A,true_beta, 'r', lw=3, zorder=9,label='truth')
plt.plot(A,prior_beta, 'blue', lw=3, zorder=9,label='prior')
plt.errorbar(A_[:,0], beta_1, yerr=3*np.sqrt(var_), fmt='x',ecolor='g',marker='s', 
mfc='g', ms=10,capsize=6,label='training set')

plt.title("Initial: %s\n"% (kernel))
plt.legend()
plt.show()

OUTPUT