我想在 TensorFlow 中实现下面的函数。由于张量(Tensor)不能直接迭代,我不知道如何在 TensorFlow 中改写下面这一行:
# Excerpt of the statement used inside expectation(): it evaluates
# rq_convariance for every (x1, x2) pair drawn from plot_xs with a nested
# Python loop and squeezes the result into a 2-D kernel matrix.
kernel_matrix = np.squeeze(np.array([[rq_convariance(1.5, x1, x2) for x1 in plot_xs] for x2 in plot_xs]))
我想使用 `tf.map_fn`,但我的函数有两个变量,无法直接套用 `tf.map_fn`。完整代码如下。
import numpy as np
from scipy import interpolate
from numpy.linalg import inv
import tensorflow as tf
import tensorflow.contrib.distributions as tcd
def rq_convariance(theta, x1, x2):
    """Return the squared-exponential covariance between ``x1`` and ``x2``.

    The value computed is ``exp(-((x1 - x2) ** 2 / 2) * theta ** 2)``.
    Because ``/`` and ``*`` have equal precedence and associate left to
    right, ``theta ** 2`` multiplies the exponent, so a larger ``theta``
    makes the kernel decay faster with distance.

    Args:
        theta: Kernel hyperparameter (scalar).
        x1: First input; a scalar or a NumPy array.
        x2: Second input; a scalar or an array broadcastable against ``x1``.

    Returns:
        The kernel value(s), with the broadcast shape of ``x1 - x2``.
    """
    kernel = np.exp(-(x1 - x2)**2 / 2 * theta**2)
    return kernel
def expectation(x):
    """Return ``k(x)^T K^{-1} r`` for a freshly sampled reference vector ``r``.

    ``K`` is the kernel matrix of ``rq_convariance`` (with ``theta=1.5``)
    over a fixed grid of 300 points on ``[-5, 5]``, ``k(x)`` is the kernel
    between every grid point and ``x``, and ``r`` is one draw from a
    multivariate normal with mean one and covariance ``K``.

    The result is random: every call draws a new reference sample.

    Args:
        x: Query location. A scalar yields a ``(1, 1)`` result.
            NOTE(review): other shapes rely on NumPy broadcasting against
            each ``(1,)``-shaped grid point -- confirm with callers.

    Returns:
        The product ``kernel_vector . inv(kernel_matrix) . ref_point`` as a
        NumPy array.
    """
    # Grid as a (300, 1) column so iterating yields one (1,)-shaped point.
    plot_xs = np.reshape(np.linspace(-5, 5, 300), (300, 1))
    # Each kernel value has shape (1,), so the nested list becomes
    # (300, 300, 1); squeeze drops the trailing axis to get (300, 300).
    kernel_matrix = np.squeeze(np.array([[rq_convariance(1.5, x1, x2) for x1 in plot_xs] for x2 in plot_xs]))
    # NOTE(review): a densely sampled kernel matrix may be ill-conditioned;
    # np.linalg.solve might be more stable than an explicit inverse -- confirm.
    inverse_kernel_matrix = inv(kernel_matrix)

    def reference_point():
        # Calculate the covariance and draw a single sample; size=1 makes
        # the returned array 2-D with one row.
        convariance = np.squeeze(kernel_matrix)
        sampled_funcs = np.random.multivariate_normal(np.ones(len(plot_xs)), convariance, size=1)
        return sampled_funcs

    # Turn the single-row sample into a column for the matrix product.
    ref_point = np.transpose(reference_point())
    # Kernel between each grid point and the query, laid out as row(s).
    kernel_vector = np.transpose(np.array([rq_convariance(1.5, x1, x) for x1 in plot_xs]))
    mu = kernel_vector.dot(inverse_kernel_matrix).dot(ref_point)
    return mu
答案 0 :(得分:1)
原始的 NumPy 实现:
def rq_convariance(theta, x1, x2):
    """Evaluate the kernel ``exp(-((x1 - x2) ** 2 / 2) * theta ** 2)``.

    Note that ``theta ** 2`` scales the exponent up (it is multiplied, not
    divided), so the kernel narrows as ``theta`` grows. The expression is
    purely element-wise, so array inputs broadcast in the usual NumPy way.

    Args:
        theta: Kernel hyperparameter (scalar).
        x1: First input; a scalar or a NumPy array.
        x2: Second input; a scalar or an array broadcastable against ``x1``.

    Returns:
        The kernel value(s), shaped like the broadcast of ``x1 - x2``.
    """
    kernel = np.exp(-(x1 - x2)**2 / 2 * theta**2)
    return kernel
# Use a smaller grid (5 points) so the whole matrix can be printed.
plot_xs = np.linspace(-5, 5, 5).reshape((5, 1))
# Outer loop walks x2 (rows), inner loop walks x1 (columns); squeeze drops
# the trailing length-1 axis left by the (1,)-shaped grid points.
kernel_matrix = np.squeeze(
    np.array([
        [rq_convariance(1.5, x1, x2) for x1 in plot_xs]
        for x2 in plot_xs
    ])
)
#Kernel_matrix output
[[1.0000000, 0.0008838, 0.0000000, 0.0000000, 0.0000000],
[0.0008838, 1.0000000, 0.0008838, 0.0000000, 0.0000000],
[0.0000000, 0.0008838, 1.0000000, 0.0008838, 0.0000000],
[0.0000000, 0.0000000, 0.0008838, 1.0000000, 0.0008838],
[0.0000000, 0.0000000, 0.0000000, 0.0008838, 1.0000000]]
转换为 TensorFlow:
# Hyperparameter value; matches the 1.5 passed to rq_convariance in the
# NumPy version. It must be defined here because the kernel expression
# below is no longer wrapped in a function that receives it.
theta = 1.5
t_plot_xs = tf.reshape(tf.linspace(-5., 5., 5, name="linspace"), (5, 1))
# Let broadcasting do the trick: subtracting the (1, 5) transpose from the
# (5, 1) column yields every pairwise difference as a (5, 5) tensor, so no
# Python-level iteration over the tensor is needed.
t_kernel = tf.exp(-0.5 * theta**2 * (t_plot_xs - tf.transpose(t_plot_xs))**2)
with tf.Session() as sess:
    print(sess.run(t_kernel))
#Output
[[1.0000000 0.0008838 0.0000000 0.0000000 0.0000000]
[0.0008838 1.0000000 0.0008838 0.0000000 0.0000000]
[0.0000000 0.0008838 1.0000000 0.0008838 0.0000000]
[0.0000000 0.0000000 0.0008838 1.0000000 0.0008838]
[0.0000000 0.0000000 0.0000000 0.0008838 1.0000000]]