KeyError:张量变量,请参考不存在的张量

时间:2018-03-08 06:33:18

标签: python tensorflow lstm

使用 LSTMCell 我训练了一个模型进行文本生成。我启动了tensorflow会话并使用 tf.global_variables_initializer()保存所有tensorflow变量。

import tensorflow as tf
sess = tf.Session()
//code blocks
run_init_op = tf.global_variables_intializer()
sess.run(run_init_op)
saver = tf.train.Saver()
#varible that makes prediction
prediction = tf.nn.softmax(tf.matmul(last,weight)+bias)
#feed the inputdata into model and trained
#saved the model
#save the tensorflow model
save_path= saver.save(sess,'/tmp/text_generate_trained_model.ckpt')
print("Model saved in the path : {}".format(save_path))

模型得到训练并保存其所有会话。链接以查看整个代码lstm_rnn.py

现在我加载了存储的模型并尝试为文档生成文本。所以,我使用以下代码恢复了模型

tf.reset_default_graph()
imported_data = tf.train.import_meta_graph('text_generate_trained_model.ckpt.meta')
with tf.Session() as sess:
    imported_meta.restore(sess,tf.train.latest_checkpoint('./'))

    #accessing the default graph which we restored
    graph = tf.get_default_graph()

    #op that we can be processed to get the output
    #last is the tensor that is the prediction of the network
    y_pred = graph.get_tensor_by_name("prediction:0")
    #generate characters
    for i in range(500):
        x = np.reshape(pattern,(1,len(pattern),1))
        x = x / float(n_vocab)
        prediction = sess.run(y_pred,feed_dict=x)
        index = np.argmax(prediction)
        result = int_to_char[index]
        seq_in = [int_to_char[value] for value in pattern]
        sys.stdout.write(result)
        patter.append(index)
        pattern = pattern[1:len(pattern)]

    print("\n Done...!")
sess.close()

我开始知道图中不存在预测变量。

  

KeyError:&#34;名称&#39;预测:0&#39;是指没有的Tensor   存在。图表中不存在操作&#39;预测&#39; <#34;

此处提供完整代码text_generation.py

虽然我保存了所有张量流变量,但预测张量未保存在张量流计算图中。我的 lstm_rnn.py 文件中有什么问题。

谢谢!

1 个答案:

答案 0 :(得分:3)

要使graph.get_tensor_by_name("prediction:0")起作用,您应该在创建它时对其进行命名。这就是你如何命名

prediction = tf.nn.softmax(tf.matmul(last,weight)+bias, name="prediction")

如果您已经训练了模型并且无法重命名张量,您仍然可以通过其默认名称获得该张量,如

y_pred = graph.get_tensor_by_name("Reshape_1:0")

如果Reshape_1不是张量的实际名称,则必须查看图中的名称并弄清楚。 您可以使用

进行检查
for op in graph.get_operations():
    print(op.name)