还原网络时,在还原的图中找不到操作

时间:2018-07-30 17:55:35

标签: tensorflow

我想使用TensorFlow 1.9在一个Python文件中训练一个神经网络,然后使用另一个Python文件恢复该网络。我尝试使用一个简单的示例来执行此操作,但是当我尝试加载“预测”操作时,出现错误。具体来说,错误是:KeyError: "The name 'prediction' refers to an Operation not in the graph."

下面是我的Python文件,用于训练和保存网络。它会生成一些示例数据并训练一个简单的神经网络,然后在每个时期保存该网络。

import numpy as np
import tensorflow as tf

input_data = np.zeros([100, 10])
label_data = np.zeros([100, 1])
for i in range(100):
    for j in range(10):
        input_data[i, j] = i * j / 1000
    label_data[i] = 2 * input_data[i, 0] + np.random.uniform(0.01)

input_placeholder = tf.placeholder(tf.float32, shape=[None, 10], name='input_placeholder')
label_placeholder = tf.placeholder(tf.float32, shape=[None, 1], name='label_placeholder')

x = tf.layers.dense(inputs=input_placeholder, units=10, activation=tf.nn.relu)
x = tf.layers.dense(inputs=x, units=10, activation=tf.nn.relu)
prediction = tf.layers.dense(inputs=x, units=1, name='prediction')

loss_op = tf.reduce_mean(tf.square(prediction - label_placeholder))
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_op)

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch_num in range(100):
        _, loss = sess.run([train_op, loss_op], feed_dict={input_placeholder: input_data, label_placeholder: label_data})
        print('epoch ' + str(epoch_num) + ', loss = ' + str(loss))
        saver.save(sess, '../Models/model', global_step=epoch_num + 1)

下面是我的用于还原网络的Python文件。它加载输入和输出占位符,以及进行预测所需的操作。但是,即使我在上面的训练代码中将操作命名为prediction,下面的代码似乎也无法在加载的图形中找到该操作。

import tensorflow as tf
import numpy as np

input_data = np.zeros([100, 10])
for i in range(100):
    for j in range(10):
        input_data[i, j] = i * j / 1000

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('../Models/model-99.meta')
    saver.restore(sess, '../Models/model-99')
    graph = tf.get_default_graph()
    input_placeholder = graph.get_tensor_by_name('input_placeholder:0')
    label_placeholder = graph.get_tensor_by_name('label_placeholder:0')
    prediction = graph.get_operation_by_name('prediction')
    pred = sess.run([prediction], feed_dict={input_placeholder: input_data})

为什么此代码找不到此操作,我该怎么做才能更正我的代码?

1 个答案:

答案 0 :(得分:1)

您必须在加载脚本中修改一行(已通过tf 1.8测试):

prediction = graph.get_tensor_by_name('prediction/BiasAdd:0')

您必须指定要访问的张量,因为预测只是密集层的名称空间。您可以在使用prediction.name保存期间检查确切名称。还原时,请使用tf.get_tensor_by_name,因为您对值感兴趣,而不是对产生它的操作感兴趣。