我试图用Tensorflow创建这个超级简单的例子,我显然不完全理解Tensorflow的API。
我有以下代码。它本来不是我的 - 我从一些演示中找到了它,但我不记得我发现它的位置,否则我会给作者一些信誉。道歉。
import tensorflow as tf
import numpy as np
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()
# Create a session saver
saver = tf.train.Saver()
# Launch the graph.
sess = tf.Session()
sess.run(init)
# Fit the line.
for step in range(201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
saver.save(sess, 'linemodel')
好的,一切都很好。我只想加载模型然后查询我的模型以获得预测值。这是我尝试过的代码:
# This is going to load the line model
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
v_ = sess.run(v)
print("This is {} with value: {}".format(v.name, v_))
# this works
# None of the below works
# Tried this as well
#fetches = {
# "input": tf.constant(10, name='input')
#}
#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')
# print(sess.run(query_value))
这是一个非常基本的问题,但是我怎样才能传递一个值并使用我的线几乎就像一个函数。我是否需要改变线模型的构造方式?我的猜测是没有设置计算图,其中输出是我们可以获得的实际变量。它是否正确?如果是这样,我该如何修改这个程序?
答案 0 :(得分:2)
您必须再次创建张量流图并将已保存的权重加载到其中。我在代码中添加了几行代码,它提供了所需的输出。请检查一下。
import tensorflow as tf
import numpy as np
sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
# load saved weights into new variables
W = all_vars[0]
b = all_vars[1]
# build TF graph
x = tf.placeholder(tf.float32)
y = tf.add(tf.multiply(W,x),b)
# Session
init = tf.global_variables_initializer()
print(sess.run(all_vars))
sess.run(init)
for i in range(2):
x_ip = np.random.rand(10).astype(np.float32) # batch_size : 10
vals = sess.run(y,feed_dict={x:x_ip})
print vals
输出:
[array([ 0.1000001], dtype=float32), array([ 0.29999995], dtype=float32)]
[-0.21707924 -0.18646611 -0.00732027 -0.14248954 -0.54388255 -0.33952206 -0.34291503 -0.54771954 -0.60995424 -0.91694558]
[-0.45050886 -0.01207681 -0.38950539 -0.25888413 -0.0103816 -0.10003483 -0.04783082 -0.83299863 -0.53189355 -0.56571382]
我希望这会有所帮助。