I am trying to compute the Hessian matrix and the gradient vector of a mean squared loss. The data generation is deliberately simple: two predictors plus an intercept term, with beta coefficients [1, 2, 3]. I want to evaluate the Hessian and the gradient at the true beta, i.e. [1, 2, 3]. (The Hessian should be a 3×3 matrix, and the gradient vector should contain 3 elements.)
The Hessian and the gradient computed by TensorFlow differ substantially from the analytical results. Moreover, the Hessian produced by TensorFlow is zero in every off-diagonal entry.
Since the true beta is almost the ordinary least squares (OLS) optimum, the gradient should be close to zero, so I suspect my TensorFlow code is wrong.
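For reference, with the loss L(beta) = ||y - X*beta||^2 / n (where X includes the intercept column), the analytical gradient is -2*X'(y - X*beta)/n and the Hessian is 2*X'X/n; these are the formulas used in the manual computation at the end of this post.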
import tensorflow as tf
import numpy as np
import random
random.seed(0)
np.random.seed(seed=0)
Simple data generation:
sample_size=1000
dim_x=2 # two predictors
x1=np.random.normal(loc=0.0, scale=1.0, size=(sample_size,1))
x2=np.random.normal(loc=0.0, scale=1.0, size=(sample_size,1))
x_train=np.concatenate((x1,x2), axis=1)
Beta=np.array([1,2,3])
mean_y=Beta[0]+x_train.dot(Beta[1:3])
y_train=mean_y+np.random.normal(loc=0.0, scale=1.0, size=(sample_size,))
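As a quick sanity check (this snippet is an addition, not part of the original computation; X_design and beta_ols are just local names here), the OLS solution on this sample should indeed land close to the true Beta:
# Sanity check: the OLS estimate on this sample should be roughly [1, 2, 3]
X_design = np.concatenate((np.ones([sample_size, 1]), x_train), axis=1)  # prepend intercept column
beta_ols, *_ = np.linalg.lstsq(X_design, y_train, rcond=None)  # least-squares fit
print(beta_ols)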
Computing the gradient vector and the Hessian with TensorFlow:
sess = tf.Session()
x_input = tf.placeholder(tf.float64, [None, dim_x+1], name='x_input')  # design matrix incl. intercept column
y_input = tf.placeholder(tf.float64, name='y_input')
coeff = tf.Variable(tf.convert_to_tensor(Beta,tf.float64), name='coeff')  # initialized at the true beta
y_fit = tf.multiply(x_input, coeff)  # element-wise product of x_input and coeff
loss_op = tf.reduce_sum(tf.pow(y_input - y_fit, 2))/sample_size
gradients_node = tf.gradients(loss_op, coeff)
hessians_node = tf.hessians(loss_op, coeff)
init = tf.global_variables_initializer()
sess.run(init)
input_x=np.concatenate((np.ones([sample_size,1]),x_train),axis=1)
gradients,loss = sess.run([gradients_node, loss_op],feed_dict={x_input: input_x,y_input: y_train.reshape(-1, 1)})
hessians = sess.run(hessians_node,feed_dict={x_input: input_x,y_input: y_train.reshape(-1, 1)})
print(gradients)
print(hessians)
sess.close()
The TensorFlow output:
[array([ 0.20178232, 0.34121165, 0.09825671])]
[array([[ 2. , 0. , 0. ],
[ 0. , 1.95256525, 0. ],
[ 0. , 0. , 1.87503835]])]
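One thing I noticed while debugging (this shape check is an addition to the code above): y_fit is not a vector of length sample_size,
print(y_fit.get_shape())  # (?, 3): tf.multiply broadcasts coeff across the columns of x_input
so the loss as written sums three per-column squared-error terms, each involving only one element of coeff. That would make every cross second derivative exactly zero, and indeed the diagonal entries TensorFlow reports match the diagonal of the analytical Hessian below.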
Computing the gradient vector and the Hessian analytically:
Resid=y_train.reshape(-1, 1)-input_x.dot(Beta).reshape(-1, 1)  # residuals at the true beta
gradientsManual=-2*np.transpose(Resid).dot(input_x)/sample_size  # -2 X'(y - X beta) / n
print(gradientsManual)
hessiansManual=2*np.transpose(input_x).dot(input_x)/sample_size  # 2 X'X / n
print(hessiansManual)
The analytical output:
[[ 0.10245713 0.06634371 0.00258758]]
[[ 2. -0.09051341 0.02723388]
[-0.09051341 1.95256525 -0.06145151]
[ 0.02723388 -0.06145151 1.87503835]]
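For comparison, here is a sketch of how I would expect the fitted values to be built if the intended model is y_hat = X*beta; it replaces the element-wise tf.multiply with a row-wise dot product, and y_input would then be fed the flat y_train rather than y_train.reshape(-1, 1) so the shapes line up:
# Sketch of the fit as a row-wise dot product, so y_fit has shape (None,)
y_fit = tf.reduce_sum(tf.multiply(x_input, coeff), axis=1)  # equivalent to X.dot(coeff)
loss_op = tf.reduce_sum(tf.pow(y_input - y_fit, 2))/sample_size
# ...and feed y_input=y_train instead of y_train.reshape(-1, 1)
With y_fit shaped (None,), the loss is the usual mean squared error, so the cross terms between elements of coeff no longer drop out of the Hessian.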