平方指数或RBF核的基本方程如下:
这里l是长度尺度,西格玛是方差参数。长度比例控制两个点看起来相似的方式,因为它只是放大了x和x'之间的距离。方差参数控制函数的平滑程度。
我想用我的训练数据集优化/训练这些参数(l和sigma)。我的训练数据集如下:
X :二维笛卡尔坐标作为输入数据
y :2-D坐标点的Wi-Fi设备的无线电信号强度(RSS)作为观察输出
根据sklearn,GaussianProcessRegressor类定义为:
class sklearn.gaussian_process.GaussianProcessRegressor(kernel=None, alpha=1e-10, optimizer=’fmin_l_bfgs_b’, n_restarts_optimizer=0, normalize_y=False, copy_X_train=True, random_state=None)
这里,optimizer
是一个字符串或可用L-BFGS-B算法作为默认优化算法(“fmin_l_bfgs_b”
)来调用。 optimizer
可以是内部支持的优化器之一,用于优化由字符串指定的内核参数,也可以是作为可调用方传递的外部定义的优化器。此外,scikit-learn中唯一可用的内部优化器是fmin_l_bfgs_b
。但是,我知道scipy package有更多的优化器。由于我想使用trust-region-reflective algorithm来优化超参数,我尝试按如下方式实现算法:
def fun_rosenbrock(Xvariable):
return np.array([10*(Xvariable[1]-Xvariable[0]**2),(1-Xvariable[0])])
Xvariable = [1.0,1.0]
kernel = C(1.0, (1e-5, 1e5)) * RBF(1, (1e-1, 1e3))
trust_region_method = least_squares(fun_rosenbrock,[10,20,30,40,50],bounds=[0,100], method ='trf')
gp = GaussianProcessRegressor(kernel=kernel, optimizer = trust_region_method, alpha =1.2, n_restarts_optimizer=10)
gp.fit(X, y)
由于我无法弄清楚在我的情况下实际上参数'fun'是什么,我使用this示例中的rosenbrock函数(示例位于页面底部)。我在控制台中收到以下错误。
我使用 scipy包来优化内核参数是否正确?如何打印参数的优化值?在我的案例中,scipy.optimize.least_squares中的参数'fun'是什么?
谢谢!
答案 0 :(得分:3)
这里有三个主要问题:
作为一个部分工作的例子,忽略内核定义以强调优化器:
import numpy as np
from scipy.optimize import minimize,least_squares
from sklearn.gaussian_process import GaussianProcessRegressor
def trust_region_optimizer(obj_func, initial_theta, bounds):
trust_region_method = least_squares(1/obj_func,initial_theta,bounds,method='trf')
return (trust_region_method.x,trust_region_method.fun)
X=np.random.random((10,4))
y=np.random.random((10,1))
gp = GaussianProcessRegressor(optimizer = trust_region_optimizer, alpha =1.2, n_restarts_optimizer=10)
gp.fit(X, y)
scipy优化器返回结果对象,使用rosenbrock测试函数的最小化作为示例:
from scipy.optimize import least_squares,rosen
res=least_squares(rosen,np.array([0,0]),method='trf')
如上所示,可以使用以下方法访问优化值:
res.x
以及要最小化的函数的结果值:
res.fun
这就是'有趣的'参数代表。但是现在优化器在内部工作,您需要从scikit-learn访问生成的函数值:
gp.log_marginal_likelihood_value_