基本张量流示例 - 线的预测

时间:2017-02-28 19:22:05

标签: python tensorflow training-data

我试图用Tensorflow创建这个超级简单的例子,我显然不完全理解Tensorflow的API。

我有以下代码。它本来不是我的 - 我从一些演示中找到了它,但我不记得我发现它的位置,否则我会给作者一些信誉。道歉。

保存训练线模型

import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

# Create a session saver
saver = tf.train.Saver()

# Launch the graph.
sess = tf.Session() 

sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
        saver.save(sess, 'linemodel')

好的,一切都很好。我只想加载模型然后查询我的模型以获得预测值。这是我尝试过的代码:

加载和查询训练线模型

# This is going to load the line model
import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
    v_ = sess.run(v)
    print("This is {} with value: {}".format(v.name, v_))
    # this works


# None of the below works
# Tried this as well
#fetches = {
#   "input": tf.constant(10, name='input')
#}

#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')

# print(sess.run(query_value))

这是一个非常基本的问题,但是我怎样才能传递一个值并使用我的线几乎就像一个函数。我是否需要改变线模型的构造方式?我的猜测是没有设置计算​​图,其中输出是我们可以获得的实际变量。它是否正确?如果是这样,我该如何修改这个程序?

1 个答案:

答案 0 :(得分:2)

您必须再次创建张量流图并将已保存的权重加载到其中。我在代码中添加了几行代码,它提供了所需的输出。请检查一下。

import tensorflow as tf
import numpy as np

sess = tf.Session() 
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()

# load saved weights into new variables
W = all_vars[0]
b = all_vars[1]

# build TF graph
x = tf.placeholder(tf.float32)
y = tf.add(tf.multiply(W,x),b)

# Session
init = tf.global_variables_initializer()
print(sess.run(all_vars))
sess.run(init)    
for i in range(2):
    x_ip = np.random.rand(10).astype(np.float32) # batch_size : 10
    vals = sess.run(y,feed_dict={x:x_ip})
    print vals

输出:

[array([ 0.1000001], dtype=float32), array([ 0.29999995], dtype=float32)]

[-0.21707924 -0.18646611 -0.00732027 -0.14248954 -0.54388255 -0.33952206  -0.34291503 -0.54771954 -0.60995424 -0.91694558]
[-0.45050886 -0.01207681 -0.38950539 -0.25888413 -0.0103816  -0.10003483 -0.04783082 -0.83299863 -0.53189355 -0.56571382]

我希望这会有所帮助。