Simple linear regression with TensorFlow

Posted: 2017-10-07 16:03:48

Tags: machine-learning tensorflow linear-regression

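I am a beginner in TensorFlow and machine learning, and I wanted to try a simple linear regression example with TensorFlow.

But after about 3700 epochs the loss stops decreasing. I don't know what is going wrong.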

After training we get W = 3.52 and b = 2.8865, so y = 3.52*x + 2.8865. For the test input x = 11 this gives y = 41.6065. But that seems wrong, because in the training data x = 10 already gives y = 48.712.
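One way to check whether 41.6065 points to a bug is to evaluate the learned line on the training years themselves. The following minimal sketch uses plain Python rather than TensorFlow and copies W = 3.52348 and b = 2.8865 from the last line of the training log. The helper `predict` is a hypothetical name used only for illustration.

```python
# Training data from the question: year index 0..10 and the corresponding price.
price = [6.757, 12.358, 10.091, 11.618, 14.064,
         16.926, 17.673, 22.271, 26.905, 34.742, 48.712]
year = list(range(11))

# Parameters copied from the final line of the training log.
W, b = 3.52348, 2.8865


def predict(x):
    """Evaluate the fitted line y = W*x + b (hypothetical helper)."""
    return W * x + b


# Compare the line with every training point.
for x, y in zip(year, price):
    print(f"x={x:2d}  actual={y:7.3f}  predicted={predict(x):7.3f}  "
          f"residual={y - predict(x):+7.3f}")

# Extrapolate one year past the training range.
print(f"x=11  predicted={predict(11):.4f}")
```

At x = 10 the line gives 3.52348*10 + 2.8865 ≈ 38.12, which is about 10.6 below the observed 48.712. The fitted line therefore already sits well under the last training point. The x = 11 prediction is simply the next point on that same line, so it also comes out below 48.712.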

The code and the loss output are below.

#Goal: predict the house price in 2017 by linear regression method
#Step: 1. load the original data
#      2. define the placeholder and variable
#      3. linear regression method
#      4. launch the graph

from __future__ import print_function

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# 1. load the original data
price = np.asarray([6.757, 12.358, 10.091, 11.618, 14.064, 
    16.926, 17.673, 22.271, 26.905, 34.742, 48.712])
year = np.asarray([0,1,2,3,4,5,6,7,8,9,10])
n_samples = price.shape[0]


# 2. define the placeholder and variable
x = tf.placeholder("float")
y_ = tf.placeholder("float")


W = tf.Variable(np.random.randn())
b = tf.Variable(np.random.randn())


# 3. linear regression method
y = tf.add(tf.multiply(x, W), b)

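# Mean squared error, further divided by 2*n_samples. reduce_mean already averages, so the
# extra division only rescales the loss (and with it the effective step size).
loss = tf.reduce_mean(tf.square(y - y_)) / (2 * n_samples)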
training_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)


# 4. launch the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(10000):
        for (year_epoch, price_epoch) in zip(year, price):  
            sess.run(training_step, feed_dict = {x: year_epoch, y_: price_epoch})

        if (epoch+1) % 50 == 0:
            loss_np = sess.run(loss, feed_dict={x: year, y_: price})
            print("Epoch: ", '%04d' % (epoch+1), "loss = ", "{:.9f}".format(loss_np), "W = ", sess.run(W), "b = ", sess.run(b))

    # print "Training finish"
    training_loss = sess.run(loss, feed_dict = {x: year, y_: price})
    print("Training cost = ", training_loss, "W = ", sess.run(W), "b = ", sess.run(b), '\n')

The loss output:

Epoch:  0050 loss =  1.231071353 W =  3.88227 b =  0.289058
Epoch:  0100 loss =  1.207471132 W =  3.83516 b =  0.630129
Epoch:  0150 loss =  1.189429402 W =  3.79423 b =  0.926415
Epoch:  0200 loss =  1.175611973 W =  3.75868 b =  1.1838
Epoch:  0250 loss =  1.165009260 W =  3.72779 b =  1.40738
Epoch:  0300 loss =  1.156855702 W =  3.70096 b =  1.60161
Epoch:  0350 loss =  1.150570631 W =  3.67766 b =  1.77033
Epoch:  0400 loss =  1.145712137 W =  3.65741 b =  1.9169
Epoch:  0450 loss =  1.141945601 W =  3.63982 b =  2.04422
Epoch:  0500 loss =  1.139016271 W =  3.62455 b =  2.15483
Epoch:  0550 loss =  1.136731029 W =  3.61127 b =  2.25091
Epoch:  0600 loss =  1.134940267 W =  3.59974 b =  2.33437
Epoch:  0650 loss =  1.133531928 W =  3.58973 b =  2.40688
Epoch:  0700 loss =  1.132419944 W =  3.58103 b =  2.46986
Epoch:  0750 loss =  1.131537557 W =  3.57347 b =  2.52458
Epoch:  0800 loss =  1.130834818 W =  3.5669 b =  2.57211
Epoch:  0850 loss =  1.130271792 W =  3.5612 b =  2.6134
Epoch:  0900 loss =  1.129818439 W =  3.55625 b =  2.64927
Epoch:  0950 loss =  1.129452229 W =  3.55194 b =  2.68042
Epoch:  1000 loss =  1.129154325 W =  3.5482 b =  2.70749
Epoch:  1050 loss =  1.128911495 W =  3.54496 b =  2.731
Epoch:  1100 loss =  1.128711581 W =  3.54213 b =  2.75143
Epoch:  1150 loss =  1.128546953 W =  3.53968 b =  2.76917
Epoch:  1200 loss =  1.128411174 W =  3.53755 b =  2.78458
Epoch:  1250 loss =  1.128297567 W =  3.53571 b =  2.79797
Epoch:  1300 loss =  1.128202677 W =  3.5341 b =  2.8096
Epoch:  1350 loss =  1.128123403 W =  3.5327 b =  2.81971
Epoch:  1400 loss =  1.128056765 W =  3.53149 b =  2.82849
Epoch:  1450 loss =  1.128000259 W =  3.53044 b =  2.83611
Epoch:  1500 loss =  1.127952814 W =  3.52952 b =  2.84274
Epoch:  1550 loss =  1.127912283 W =  3.52873 b =  2.84849
Epoch:  1600 loss =  1.127877355 W =  3.52804 b =  2.85349
Epoch:  1650 loss =  1.127847791 W =  3.52744 b =  2.85783
Epoch:  1700 loss =  1.127822518 W =  3.52692 b =  2.8616
Epoch:  1750 loss =  1.127801418 W =  3.52646 b =  2.86488
Epoch:  1800 loss =  1.127782702 W =  3.52607 b =  2.86773
Epoch:  1850 loss =  1.127766728 W =  3.52573 b =  2.8702
Epoch:  1900 loss =  1.127753139 W =  3.52543 b =  2.87234
Epoch:  1950 loss =  1.127740979 W =  3.52517 b =  2.87421
Epoch:  2000 loss =  1.127731323 W =  3.52495 b =  2.87584
Epoch:  2050 loss =  1.127722263 W =  3.52475 b =  2.87725
Epoch:  2100 loss =  1.127714872 W =  3.52459 b =  2.87847
Epoch:  2150 loss =  1.127707958 W =  3.52444 b =  2.87953
Epoch:  2200 loss =  1.127702117 W =  3.52431 b =  2.88045
Epoch:  2250 loss =  1.127697825 W =  3.5242 b =  2.88126
Epoch:  2300 loss =  1.127693415 W =  3.52411 b =  2.88195
Epoch:  2350 loss =  1.127689362 W =  3.52402 b =  2.88255
Epoch:  2400 loss =  1.127686620 W =  3.52395 b =  2.88307
Epoch:  2450 loss =  1.127683759 W =  3.52389 b =  2.88352
Epoch:  2500 loss =  1.127680898 W =  3.52383 b =  2.88391
Epoch:  2550 loss =  1.127679348 W =  3.52379 b =  2.88425
Epoch:  2600 loss =  1.127677798 W =  3.52374 b =  2.88456
Epoch:  2650 loss =  1.127675653 W =  3.52371 b =  2.88483
Epoch:  2700 loss =  1.127674222 W =  3.52368 b =  2.88507
Epoch:  2750 loss =  1.127673268 W =  3.52365 b =  2.88526
Epoch:  2800 loss =  1.127672315 W =  3.52362 b =  2.88543
Epoch:  2850 loss =  1.127671123 W =  3.5236 b =  2.88559
Epoch:  2900 loss =  1.127670288 W =  3.52358 b =  2.88572
Epoch:  2950 loss =  1.127670050 W =  3.52357 b =  2.88583
Epoch:  3000 loss =  1.127669215 W =  3.52356 b =  2.88592
Epoch:  3050 loss =  1.127668500 W =  3.52355 b =  2.88599
Epoch:  3100 loss =  1.127668381 W =  3.52354 b =  2.88606
Epoch:  3150 loss =  1.127667665 W =  3.52353 b =  2.88615
Epoch:  3200 loss =  1.127667546 W =  3.52352 b =  2.88621
Epoch:  3250 loss =  1.127667069 W =  3.52351 b =  2.88626
Epoch:  3300 loss =  1.127666950 W =  3.5235 b =  2.8863
Epoch:  3350 loss =  1.127666354 W =  3.5235 b =  2.88633
Epoch:  3400 loss =  1.127666593 W =  3.5235 b =  2.88637
Epoch:  3450 loss =  1.127666593 W =  3.52349 b =  2.8864
Epoch:  3500 loss =  1.127666235 W =  3.52349 b =  2.88644
Epoch:  3550 loss =  1.127665997 W =  3.52348 b =  2.88646
Epoch:  3600 loss =  1.127665639 W =  3.52348 b =  2.88648
Epoch:  3650 loss =  1.127665639 W =  3.52348 b =  2.88649
Epoch:  3700 loss =  1.127665997 W =  3.52348 b =  2.8865
(Epochs 3750 through 9950 print exactly the same values: loss =  1.127665997 W =  3.52348 b =  2.8865)
Epoch:  10000 loss =  1.127665997 W =  3.52348 b =  2.8865
Training cost =  1.12767 W =  3.52348 b =  2.8865 

1 answer:

Answer 0 (score: 4)

Your assumption that the price lies on a straight line as a function of year is incorrect. Check the plot of price against year:

[Figure: price vs. year]

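The linear hypothesis you chose therefore produces the best straight line it can. That line minimizes the total cost over all the input points, which does not mean it passes through any particular point. When you test a point outside the training range, the prediction lies on that straight line, the one that best fits the inputs you provided.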
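The best straight line can also be computed directly, without gradient descent. The sketch below uses plain Python and the closed-form ordinary least-squares formulas. `least_squares_line` is a hypothetical helper name. The question's loss is a rescaled mean squared error, so its minimizer is this same line.

```python
# Training data from the question.
price = [6.757, 12.358, 10.091, 11.618, 14.064,
         16.926, 17.673, 22.271, 26.905, 34.742, 48.712]
year = list(range(11))


def least_squares_line(xs, ys):
    """Closed-form ordinary least squares for y = W*x + b (hypothetical helper)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope = covariance(x, y) / variance(x); the intercept makes the line pass through the means.
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    w = sxy / sxx
    return w, mean_y - w * mean_x


W, b = least_squares_line(year, price)
print(f"W = {W:.4f}, b = {b:.4f}")
print(f"prediction at x = 11: {W * 11 + b:.4f}")
```

On the question's data this gives W ≈ 3.406, b ≈ 3.162, and a prediction of about 40.63 at x = 11. Even the exact best-fit line predicts well below 48.712 for the next year. The limit on the prediction comes from the straight-line hypothesis itself, not from the training procedure.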

You raised two issues.

1. The cost does not decrease: try lowering the learning rate, and your cost will go down further.
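One rough way to see the effect of the learning rate is to replay the question's training loop in plain Python. This sketch makes several assumptions:

* It uses float64 arithmetic instead of TensorFlow's float32.
* It starts from W = b = 0 instead of `np.random.randn()`.
* `sgd_per_sample` and `full_loss` are hypothetical helper names.

Like the question's loop, it applies one gradient step per (year, price) pair for the per-sample loss (W*x + b - y)**2 / (2*n). At the end of training it reports the loss over the full data set.

```python
# Training data from the question.
price = [6.757, 12.358, 10.091, 11.618, 14.064,
         16.926, 17.673, 22.271, 26.905, 34.742, 48.712]
year = list(range(11))


def sgd_per_sample(xs, ys, lr, epochs, w=0.0, b=0.0):
    """One plain gradient-descent step per (x, y) pair, in a fixed order every epoch,
    for the per-sample loss (w*x + b - y)**2 / (2*n) (hypothetical helper)."""
    n = len(xs)
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            err = w * x + b - y  # prediction error on this sample
            # Both gradients use the same err, as a single optimizer step would.
            w, b = w - lr * err * x / n, b - lr * err / n
    return w, b


def full_loss(xs, ys, w, b):
    """The question's evaluation loss: mean squared error divided by 2*n (hypothetical helper)."""
    n = len(xs)
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n / (2 * n)


# The smaller learning rate converges more slowly, so it is given more epochs.
results = {}
for lr, epochs in ((0.01, 10000), (0.001, 50000)):
    w, b = sgd_per_sample(year, price, lr, epochs)
    results[lr] = (w, b, full_loss(year, price, w, b))
    print(f"lr={lr}: W={w:.5f}  b={b:.5f}  loss={results[lr][2]:.9f}")
```

With lr = 0.01 this sketch settles at roughly the values in the question's log (W ≈ 3.52, b ≈ 2.89). The loss stops changing because every epoch now ends at the same point. With lr = 0.001 the epoch-end parameters settle closer to the least-squares line, and the loss ends lower, near this loss's minimum of about 1.117. In this sketch the flat loss means training has settled, not that it is broken. With a constant step and one sample per update, the point where it settles depends on the learning rate.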

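2. Your output for year = 11 is wrong: this happens for the reason given above. To fix it, you need to change the hypothesis. Include a squared term and check again, for example y = ax^2 + bx + c. This hypothesis will fit your data better.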
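The suggested quadratic hypothesis can be tried without TensorFlow by solving the least-squares normal equations directly. This minimal sketch makes that substitution and uses plain Python. `fit_poly`, `solve` and `evaluate` are hypothetical helpers. The solver is a tiny Gaussian elimination that is adequate for a 2x2 or 3x3 system but is not a general-purpose solver. The sketch fits both the straight line and y = ax^2 + bx + c to the question's data, then compares their squared errors and their predictions at x = 11.

```python
# Training data from the question.
price = [6.757, 12.358, 10.091, 11.618, 14.064,
         16.926, 17.673, 22.271, 26.905, 34.742, 48.712]
year = list(range(11))


def solve(m, rhs):
    """Solve m @ v = rhs by Gaussian elimination with partial pivoting (small systems only)."""
    n = len(rhs)
    aug = [row[:] + [r] for row, r in zip(m, rhs)]  # augmented matrix [m | rhs]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            f = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= f * aug[col][c]
    v = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        v[r] = (aug[r][n] - sum(aug[r][c] * v[c] for c in range(r + 1, n))) / aug[r][r]
    return v


def fit_poly(xs, ys, degree):
    """Least-squares polynomial fit via the normal equations (F^T F) coef = F^T y.
    Returns coefficients from the highest power down, e.g. [a, b, c] for degree 2."""
    feats = [[float(x) ** k for k in range(degree, -1, -1)] for x in xs]
    d = degree + 1
    ftf = [[sum(f[i] * f[j] for f in feats) for j in range(d)] for i in range(d)]
    fty = [sum(f[i] * y for f, y in zip(feats, ys)) for i in range(d)]
    return solve(ftf, fty)


def evaluate(coefs, x):
    """Evaluate a polynomial from highest-power-first coefficients (Horner's rule)."""
    result = 0.0
    for c in coefs:
        result = result * x + c
    return result


fits = {}
for degree in (1, 2):
    coefs = fit_poly(year, price, degree)
    sse = sum((evaluate(coefs, x) - y) ** 2 for x, y in zip(year, price))
    fits[degree] = (coefs, sse)
    print(f"degree {degree}: coefs={[round(c, 4) for c in coefs]}  "
          f"SSE={sse:.3f}  prediction at x=11: {evaluate(coefs, 11):.3f}")
```

On the question's data the squared term comes out positive (a ≈ 0.49), and the training error drops substantially. The prediction at x = 11 rises to about 53.3, compared with about 40.6 for the straight line. A lower training error does not by itself prove that the quadratic extrapolates better. It only shows that the quadratic follows the curvature in these eleven points more closely.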