使用TensorFlow计算简单线性回归时,为什么会得到[nan]?

时间:2017-12-21 15:59:48

标签: python tensorflow machine-learning linear-regression

当我使用TensorFlow计算简单线性回归时,得到[nan],包括:w,b和loss。

这是我的代码:

import tensorflow as tf

w = tf.Variable(tf.zeros([1]), tf.float32)
b = tf.Variable(tf.zeros([1]), tf.float32)
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

liner = w*x+b

loss = tf.reduce_sum(tf.square(liner-y))

train = tf.train.GradientDescentOptimizer(1).minimize(loss)

sess = tf.Session()

x_data = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000]
y_data = [265000, 324000, 340000, 412000, 436000, 490000, 574000, 585000, 680000]                                                    

sess.run(tf.global_variables_initializer())

for i in range(1000):
    sess.run(train, {x: x_data, y: y_data})

nw, nb, nloss = sess.run([w, b, loss], {x: x_data, y: y_data})

print(nw, nb, nloss)

输出:

[ nan] [ nan] nan

Process finished with exit code 0

为什么会发生这种情况,我该如何解决?

2 个答案:

答案 0 :(得分:1)

使用如此高的学习率(在您的情况下为1),您正在满溢。尝试学习率为0.001。此外,您的数据需要除以1000并且迭代次数增加,它应该工作。这是我测试并完美运行的代码。

x_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y_data = [265, 324, 340, 412, 436, 490, 574, 585, 680]

plt.plot(x_data, y_data, 'ro', label='Original data')
plt.legend()
plt.show()

W = tf.Variable(tf.random_uniform([1], 0, 1))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b

loss = tf.reduce_mean(tf.square(y - y_data))

optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)

for step in range(0,50000):
   sess.run(train)
   print(step, sess.run(loss))
print (step, sess.run(W), sess.run(b))

plt.plot(x_data, y_data, 'ro')
plt.plot(x_data, sess.run(W) * x_data + sess.run(b))
plt.legend()
plt.show()

答案 1 :(得分:1)

这给出了我认为的解释:

for i in range(10):
     print(sess.run([train, w, b, loss], {x: x_data, y: y_data}))

给出以下结果:

[None, array([  4.70380012e+10], dtype=float32), array([ 8212000.], dtype=float32), 2.0248419e+12] 
[None, array([ -2.68116614e+19], dtype=float32), array([ -4.23342041e+15], dtype=float32),
6.3058345e+29] 
[None, array([  1.52826476e+28], dtype=float32), array([  2.41304958e+24], dtype=float32), inf] [None, array([
-8.71110858e+36], dtype=float32), array([ -1.37543819e+33], dtype=float32), inf] 
[None, array([ inf], dtype=float32), array([ inf], dtype=float32), inf]

你的学习成绩太大,所以你过度纠正"每次迭代时w的值(见它在负和正之间振荡,绝对值增加)。您获得越来越高的值,直到某些东西达到无穷大,从而产生Nan值。只需降低(很多)学习率。