我正在尝试在TensorFlow中学习线性变换模型的权重。似乎训练损失非常低,但产生的权重与转换中使用的实际权重大不相同。
除了使用W
和b
的标量之外我做同样的事情,我能够恢复原始参数,但当我将它扩展到矩阵乘法时,如下所示,“学到的”参数与初始化时的参数略有不同(随机),并且在使用的实际值(_W
,_b
)附近没有。如何让这个模型报告线性变换的学习值?
def get_linear_relationship(x_shape, y_shape):
W = np.random.random((y_shape[0], x_shape[0]))
b = np.random.random(y_shape)
x = np.random.random(x_shape)
y = np.matmul(W, x) + b
return x, y, W, b
def linear_model():
with tf.Session() as sess:
x_shape = (3, 1)
y_shape = (3, 1)
_x, _y, _W, _b = get_linear_relationship(x_shape, y_shape)
W = tf.Variable(np.random.random((y_shape[0], x_shape[0])), dtype=tf.float32)
b = tf.Variable(np.random.random(y_shape), dtype=tf.float32)
x = tf.placeholder(shape=x_shape, dtype=tf.float32)
model = tf.matmul(W, x) + b
y = tf.placeholder(shape=y_shape, dtype=tf.float32)
loss = tf.reduce_sum(tf.square(model - y))
optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)
# initialization
init = tf.global_variables_initializer()
sess.run(init)
result = sess.run(loss, {x: _x, y: _y})
print("Random: ", result)
for i in range(1000):
sess.run(train, {x: _x, y: _y})
print("Actual W, b:\n", _W, _b)
print("Learned:\n", sess.run([W, b]))
trained_result = sess.run(loss, {x: _x, y: _y})
print("Training loss: ", trained_result)
以下是输出示例:
Random: 0.854376
Actual W, b:
[[ 0.68397062 0.48808535 0.0248331 ]
[ 0.15806422 0.37479114 0.1709631 ]
[ 0.45631878 0.7785539 0.69242146]] [[ 0.92418495]
[ 0.41903298]
[ 0.92627156]]
Learned:
[array([[ 0.17904782, 0.58858037, 0.05749775],
[ 0.63658452, 0.47282287, 0.40709457],
[ 0.05679467, 0.62718385, 0.39558661]], dtype=float32), array([[ 0.87702203],
[ 0.23417197],
[ 1.13123035]], dtype=float32)]
Training loss: 0.00073651
正如您所看到的,模型中的变量W
和b
与从_x
到_y
的转换中实际发现的变量完全不同,尽管损失函数大幅减少。为什么它们不同,我如何获得与真实参数接近的学习参数?
感谢您的帮助。
答案 0 :(得分:0)
问题是您只是一遍又一遍地训练使用相同的数据点。由于模型只有一个数据点可以从中学习线性关系(具有多个自由度),因此没有希望匹配/外推任何数据点以外的任何数据点。要简单地解决这个问题,请为模型提供更多转换数据点的示例以供学习。修改你的训练循环
for i in range(10000):
_x = np.random.random(x_shape)
sess.run(train, {x: _x, y: np.matmul(_W, _x) + _b})
您应该能够恢复原始的W
和b
参数。