我对高斯进程和python还是很陌生。 我正在尝试为3d模型生成一个非常简单的高斯回归。
我有一个非常简单的Python函数代码:
import numpy as np
def exponential_cov(x, y, params):
return params[0] * np.exp( -0.5 * params[1] * np.subtract.outer(x, y)**2)
def conditional(x_new, x, y, params):
B = exponential_cov(x_new, x, params)
C = exponential_cov(x, x, params)
A = exponential_cov(x_new, x_new, params)
mu = np.linalg.inv(C).dot(B.T).T.dot(y)
sigma = A - B.dot(np.linalg.inv(C).dot(B.T))
return(mu.squeeze(), sigma.squeeze())
import matplotlib.pylab as plt
# GP PRIOR
tu = [1, 10]
Si_tu = exponential_cov(0, 0, tu)
xpts = np.arange(-5, 5, step=0.01)
plt.errorbar(xpts, np.zeros(len(xpts)), yerr=Si_tu, capsize=0, color='#95daed', alpha=0.5, label='error') #error
plt.plot(xpts, np.zeros(len(xpts)), linestyle='dashed', color='#3105b2', linewidth=2.5, label='mu'); #mu
# GP FOR 1ST POINT
x = [1.]
y = np.sin(x)+np.cos(np.sqrt(15)*x)
Si_1 = exponential_cov(x, x, tu)
def predict(x, data, kernel, params, sigma, t):
k = [kernel(x, y, params) for y in data]
Sinv = np.linalg.inv(sigma)
y_pred = np.dot(k, Sinv).dot(t)
sigma_new = kernel(x, x, params) - np.dot(k, Sinv).dot(k)
return y_pred, sigma_new
x_pred = np.linspace(-5, 5, 1000) #change step here!!
print "x_pred="
print(x_pred)
predictions = [predict(i, x, exponential_cov, tu, Si_1, y) for i in x_pred]
y_pred, sigmas = np.transpose(predictions)
print "y_pred ="
print(y_pred )
print "sigmas ="
print(sigmas )
# GP FOR 2ND POINT
m, s = conditional([-1], x, y, tu)
y2 = np.sin(-1)+np.cos(np.sqrt(15)*(-1))
x.append(-1)
y=np.append(y,y2)
Si_2 = exponential_cov(x, x, tu)
predictions = [predict(i, x, exponential_cov, tu, Si_2, y) for i in x_pred]
y_pred, sigmas = np.transpose(predictions)
print "y_pred ="
print(y_pred )
print "sigmas ="
print(sigmas )
通过使用此代码,我对函数np.sin(x) + np.cos(np.sqrt(15) * x)
的拟合结果非常好,但是我真正想做的是对函数Z = np.sin(2*X) * np.cos(2*Y) / 2
尝试相同的高斯过程。
我知道这个想法基本相同,但是我无法将python代码适应[x,y]输入以获得z。
非常感谢您的帮助,提示或链接!
答案 0 :(得分:0)
在前面,函数的输入是1-D,然后新函数是2-D。因此,您必须更改协方差函数,例如,使用基于ard的内核,请参考cook book for kernel。另外,您可以对2-D做各向同性核,只需确保选择合适的距离函数(例如L2-norm)和单个长度刻度即可。