如何使用线性回归训练的模型预测张量流中的答案?

时间:2018-01-04 06:26:27

标签: python tensorflow

我修改了一些非常简单的教程代码(下面)来训练线性回归模型(方式简单,因为我试图了解它是如何工作的)。当我得到预测结果时,训练工作正常。但是,我发现不可能找出输入一段测试数据(x)的语法,并让模型吐出预测的y值。我得到的错误是“无法为具有形状(?,1)的Tensor提供形状值(1,)。”

有人可以看看最后一节(#rebuild the graph,这样我们可以根据测试数据进行预测......)并指出我正确的方向吗?提前谢谢。

x的样本数据位于CSV

中每行的第0列
import numpy as np
import tensorflow as tf
import csv

//# Model linear regression y = Wx + b  
x = tf.placeholder(tf.float32, [None, 1], name='varx')  
W = tf.Variable(tf.zeros([1,1]), name='W')  
b = tf.Variable(tf.zeros([1]),name='b')  
product = tf.matmul(x,W)  
y = tf.placeholder(tf.float32,[None,1],name='vary')  
y = product + b  
y_ = tf.placeholder(tf.float32, [None, 1], name='varouty')  

//# Cost function sum((y_-y)**2)  
cost = tf.reduce_mean(tf.square(y_-y))

//# Training using Gradient Descent to minimize cost  
train_step = tf.train.GradientDescentOptimizer(0.0000001).minimize(cost)  

sess = tf.Session()  
init = tf.initialize_all_variables()  
sess.run(init)  
epochs = 500  

list1 = []  
with open(r'c:\temp\book1.csv','r') as f:  
    reader = csv.reader(f)  
    for row in reader:  
        i2 = int(row[0])  
        list1.append(i2)  


//#now loop through the list and add the items  
for i in range(epochs):  
  //# Create fake data for y = W.x + b where W = 2, b = 0
    i0 = list1[i]  
    xs = np.array([[i0]])   
    ys = np.array([[2*i0]])  
  //# Train  
    feed = { x: xs, y_: ys }  
    sess.run(train_step, feed_dict=feed)  
    print("xs=" + str(xs))  
    print("ys=" + str(ys))  
    print("After %d iterations:" % i)  
    print("W: %f" % sess.run(W))  
    print("b: %f" % sess.run(b))  



//# NOTE: W should be close to 2, and b should be close to 0

//#rebuild the graph so we can make a prediction from test data....
xg = tf.get_default_graph()  

x_input = xg.get_tensor_by_name('varx:0')  
y_output = xg.get_tensor_by_name('vary:0')  
W1 = xg.get_tensor_by_name('W:0')  
b1 = xg.get_tensor_by_name('b:0')  

with tf.Session(graph=xg) as sess1:  
    x_example = [3]  
    y_prediction = sess1.run(y_output,feed_dict={x_input:x_example})  
    print(y_prediction)  

0 个答案:

没有答案