I have a very simple TensorFlow model.
However, I can't seem to train it to produce the desired output.
The sample code is below:
import numpy as np
import tensorflow as tf
input_values = np.array([
    [[0.3], [0.3], [0.2]],    # Sum 0.8
    [[0.2], [0.3], [0.4]],    # Sum 0.9
    [[0.9], [0.05], [0.05]],  # Sum 1.0
    [[0.1], [0.1], [0.9]],    # Sum 1.1
    [[0.2], [0.5], [0.5]],    # Sum 1.2
])
target_values = np.array([
    [0.8], [0.9], [1.0], [1.1], [1.2]
])
targets = tf.Variable(target_values, trainable=False)
inputs = tf.placeholder(tf.float32, shape=(None, None, 1))
w0 = tf.get_variable('w0', shape=(1, 8), initializer=tf.truncated_normal_initializer())
b0 = tf.get_variable('b0', shape=8, initializer=tf.truncated_normal_initializer())
w1 = tf.get_variable('w1', shape=(8, 1), initializer=tf.truncated_normal_initializer())
b1 = tf.get_variable('b1', shape=1, initializer=tf.truncated_normal_initializer())
output = tf.reshape(inputs, (-1, 1))
output = tf.nn.relu(tf.matmul(output, w0) + b0)
output = tf.nn.relu(tf.matmul(output, w1) + b1)
output = tf.reshape(output, (5, 3, 1))
output = tf.reduce_sum(output, axis=1)
loss = tf.losses.mean_squared_error(output, targets)
feed_dict = {inputs: input_values}
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
for _ in range(1000):
    sess.run(train_step, feed_dict=feed_dict)
print('final loss : %0.2f' % loss.eval(feed_dict))
print('output : \n%s' % output.eval(feed_dict))
However, after 1,000 training iterations, it produces the following output:
[array([[ 1.16963911],
[ 1.15914857],
[ 0.75045645],
[ 0.75721866],
[ 0.97209203]], dtype=float32)]
But the output I want is:
[array([[ 0.8],
[ 0.9],
[ 1.0],
[ 1.1],
[ 1.2]], dtype=float32)]
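I would expect such a simple model to fit very accurately after only a small number of training iterations.
The model is flexible enough to match the desired output. Setting the weights to these values produces it: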
w0 = tf.Variable([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.float32)
b0 = tf.get_variable('b0', shape=8, initializer=tf.zeros_initializer())
w1 = tf.Variable([[1], [0], [0], [0], [0], [0], [0], [0]], dtype=tf.float32)
b1 = tf.get_variable('b1', shape=1, initializer=tf.zeros_initializer())
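As a sanity check on the claim above, here is a minimal NumPy sketch that replays the same forward pass (reshape, two dense layers with ReLU, then a sum over the middle axis) using those hand-set weights. It runs outside the TensorFlow graph, and the helper name `forward` is hypothetical, introduced only for illustration.

```python
import numpy as np

input_values = np.array([
    [[0.3], [0.3], [0.2]],    # Sum 0.8
    [[0.2], [0.3], [0.4]],    # Sum 0.9
    [[0.9], [0.05], [0.05]],  # Sum 1.0
    [[0.1], [0.1], [0.9]],    # Sum 1.1
    [[0.2], [0.5], [0.5]],    # Sum 1.2
])
target_values = np.array([[0.8], [0.9], [1.0], [1.1], [1.2]])


def forward(x, w0, b0, w1, b1):
    """Hypothetical helper mirroring the graph: dense+ReLU, dense+ReLU, sum."""
    n, t, _ = x.shape
    h = x.reshape(-1, 1)                   # (n * t, 1)
    h = np.maximum(h @ w0 + b0, 0.0)       # first layer with ReLU, (n * t, 8)
    h = np.maximum(h @ w1 + b1, 0.0)       # second layer with ReLU, (n * t, 1)
    return h.reshape(n, t, 1).sum(axis=1)  # sum over the middle axis, (n, 1)


# Hand-set weights from the question: only the first hidden unit is used.
w0 = np.zeros((1, 8))
w0[0, 0] = 1.0
b0 = np.zeros(8)
w1 = np.zeros((8, 1))
w1[0, 0] = 1.0
b1 = np.zeros(1)

output = forward(input_values, w0, b0, w1, b1)
matches = np.allclose(output, target_values)
print(output)
print('matches targets:', matches)
```

Because every input is non-negative, both ReLUs act as the identity on the single active unit, so the network reduces to a plain sum of the three inputs in each row.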
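How can I:
Update
If I change the activation function from tf.nn.relu to tf.nn.sigmoid, the model fits reasonably well after 100,000 iterations (59 seconds on my machine).
Here is the output: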
[[ 0.8082366 ]
[ 0.906322 ]
[ 1.04254854]
[ 1.09336698]
[ 1.21020865]]