TensorFlow does not recognize the graph I imported

Time: 2017-10-04 02:45:11

Tags: python tensorflow

I am trying to reuse a graph defined in another .py file with tf.train.import_meta_graph().

test.py is the code that trains and saves my model. Here is test.py:

import tensorflow as tf

W = tf.Variable(tf.random_normal([1]))
b= tf.Variable(tf.random_normal([1]))
X= tf.placeholder(dtype='float32',shape=None)
Y= tf.placeholder(dtype='float32',shape=[None])

Y_ = W*X +b
Y_ =tf.identity(Y_,name="Y_")
tf.add_to_collection("Y_",Y_)
tf.add_to_collection("X",X)
cost = tf.reduce_mean(tf.square(Y_-Y))
train = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)

if __name__ == "__main__":
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        for i in range(10000):
            sess.run([train],feed_dict={X:[1,2,3],Y:[2,4,6]})
            print((sess.run(Y_,feed_dict={X:[1,2,3],Y:[2,4,6]})))
        saver.save(sess,"debug/foo")

test2.py is the code that loads the saved model. Here is test2.py:

import tensorflow as tf
import test


with tf.Session() as sess:
    import_model = tf.train.import_meta_graph("debug/foo.meta")
    import_model.restore(sess,"debug/foo")
    print("restored")
    result= sess.run(['Y_:0'],feed_dict={'X:0':[1,2,3]})

However, when I import the graph in test2.py and try to run it, I get the following error:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'X:0' refers 
to a Tensor which does not exist. The operation, 'X', does not exist in the 
graph.

What am I doing wrong?

I am using Python 3.5 on Windows 7, and my TensorFlow version is 1.2.

1 Answer:

Answer 0 (score: 1)

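The tensor does not exist because X was never given a name. The Python variable name X is not the name of the tensor in the graph: a placeholder created without name= gets a default name such as Placeholder, so the graph contains no operation called 'X'. You should write: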

X = tf.placeholder(dtype=tf.float32, name='X')
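To see why 'X:0' cannot be found, it helps to print the names TensorFlow actually assigns. The following is a small check that uses only the standard graph API; the tf1 alias is an addition of mine (an assumption, not part of the original code) so the snippet also runs on TensorFlow 2.x through tf.compat.v1:

```python
import tensorflow as tf

# Assumption: use the TF 1.x graph API on both TF 1.x and TF 2.x.
# Newer releases expose it as tf.compat.v1; older 1.x releases fall back to tf itself.
tf1 = getattr(getattr(tf, "compat", None), "v1", tf)

graph = tf.Graph()
with graph.as_default():
    unnamed = tf1.placeholder(tf.float32)          # like X in test.py: no name= argument
    named = tf1.placeholder(tf.float32, name="X")  # the suggested fix
    # Every operation name that actually exists in the graph.
    op_names = [op.name for op in graph.get_operations()]

print(unnamed.name)  # Placeholder:0
print(named.name)    # X:0
print(op_names)      # ['Placeholder', 'X']
```

In test.py, the unnamed X therefore ends up in the graph as Placeholder:0 (and Y as Placeholder_1:0), which is why the lookup of 'X:0' fails.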

The following code works. Note that both the placeholder Y and the output sum are given explicit names:

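import os
import tensorflow as tf

X = tf.Variable(tf.random_normal([1]))
Y = tf.placeholder(dtype=tf.float32, name='Y')  # named, so it can be fed as 'Y:0' later
Z = tf.add(X, Y, name='sum')                     # named, so it can be fetched as 'sum:0'
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(Z, {Y: 4}))

    os.makedirs('/tmp/model', exist_ok=True)     # make sure the checkpoint directory exists
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, '/tmp/model/my_model')

# Clear the default graph so the saved graph is imported into an empty one.
tf.reset_default_graph()

with tf.Session() as sess:
    loader = tf.train.import_meta_graph('/tmp/model/my_model.meta')
    # restore() loads the saved variable values, so no initializer is needed here.
    # It returns None, so do not assign its result back to loader.
    loader.restore(sess, '/tmp/model/my_model')
    Z = tf.get_default_graph().get_tensor_by_name('sum:0')

    print(sess.run(Z, {'Y:0': 4}))
    print(sess.run('sum:0', {'Y:0': 4}))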
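Because test.py already calls tf.add_to_collection("X", X) and tf.add_to_collection("Y_", Y_), another option is to look the tensors up through those collections, which does not depend on their graph names at all. The sketch below relies on tf.train.import_meta_graph restoring the collections stored in the .meta file. The temporary checkpoint path and the tf1 alias are illustrative assumptions, not part of the original code:

```python
import os
import tempfile

import tensorflow as tf

# Assumption: use the TF 1.x graph API on both TF 1.x and TF 2.x.
tf1 = getattr(getattr(tf, "compat", None), "v1", tf)

# Hypothetical checkpoint location in a throwaway directory.
ckpt = os.path.join(tempfile.mkdtemp(), "foo")

# Save: X is deliberately left unnamed, as in test.py.
with tf.Graph().as_default():
    W = tf1.Variable(tf1.random_normal([1]))
    X = tf1.placeholder(tf.float32)
    Y_ = tf1.identity(W * X, name="Y_")
    tf1.add_to_collection("X", X)
    tf1.add_to_collection("Y_", Y_)
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        w_saved = sess.run(W)
        tf1.train.Saver().save(sess, ckpt)

# Restore into a fresh graph and fetch the tensors through their collections.
with tf.Graph().as_default():
    with tf1.Session() as sess:
        saver = tf1.train.import_meta_graph(ckpt + ".meta")
        saver.restore(sess, ckpt)
        X = tf1.get_collection("X")[0]   # the placeholder, whatever its graph name is
        Y_ = tf1.get_collection("Y_")[0]
        result = sess.run(Y_, feed_dict={X: [1.0, 2.0, 3.0]})

print(X.name)  # Placeholder:0
print(result)  # the restored W multiplied by [1, 2, 3]
```

Naming the placeholder is still the simpler fix; the collection lookup is mainly useful when you cannot change the code that built the graph.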