我正在尝试使用tf.train.import_meta_graph()重用另一个.py文件中的图表
test.py是我训练/保存模型的代码。下面的代码是test.py
import tensorflow as tf
W = tf.Variable(tf.random_normal([1]))
b= tf.Variable(tf.random_normal([1]))
X= tf.placeholder(dtype='float32',shape=None)
Y= tf.placeholder(dtype='float32',shape=[None])
Y_ = W*X +b
Y_ =tf.identity(Y_,name="Y_")
tf.add_to_collection("Y_",Y_)
tf.add_to_collection("X",X)
cost = tf.reduce_mean(tf.square(Y_-Y))
train = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
if __name__ == "__main__":
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for i in range(10000):
sess.run([train],feed_dict={X:[1,2,3],Y:[2,4,6]})
print((sess.run(Y_,feed_dict={X:[1,2,3],Y:[2,4,6]})))
saver.save(sess,"debug/foo")
test2.py是我加载以前模型的代码 。以下代码为test2.py
import tensorflow as tf
import test
with tf.Session() as sess:
import_model = tf.train.import_meta_graph("debug/foo.meta")
import_model.restore(sess,"debug/foo")
print("restored")
result= sess.run(['Y_:0'],feed_dict={'X:0':[1,2,3]})
但是,在test2.py中,当我导入图表并尝试运行它时。它给我以下错误
TypeError: Cannot interpret feed_dict key as Tensor: The name 'X:0' refers
to a Tensor which does not exist. The operation, 'X', does not exist in the
graph.
我做错了什么?
我正在使用python 3.5和窗口7,而我的tensorflow版本是1.2
答案 0 :(得分:1)
张量不存在,因为X
没有名字。你应该写
X = tf.placeholder(dtype=tf.float32, name='X')
以下代码有效:
import tensorflow as tf
X = tf.Variable(tf.random_normal([1]))
Y = tf.placeholder(dtype=tf.float32, name='Y')
Z = tf.add(X, Y, name='sum')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(Z, {Y: 4})
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, '/tmp/model/my_model')
tf.reset_default_graph()
with tf.Session() as sess:
loader = tf.train.import_meta_graph('/tmp/model/my_model.meta')
sess.run(tf.global_variables_initializer())
loader = loader.restore(sess, '/tmp/model/my_model')
Z = tf.get_default_graph().get_tensor_by_name('sum:0')
print sess.run(Z, {'Y:0': 4})
print sess.run('sum:0', {'Y:0': 4})