TensorFlow模型恢复ValueError - 至少两个变量具有相同的名称

时间:2017-07-08 16:03:35

标签: python tensorflow mnist

我已经保存了一个模型,现在我正在尝试恢复它,在恢复它之后第一次正常工作但是当我按下“TEST'同一个正在运行的程序上的按钮,以测试另一个图像,它给出错误

ValueError:至少有两个变量具有相同的名称:Variable_2 / Adam

def train_neural_network(x):
    prediction = neural_network_model(x)#logits

    softMax=tf.nn.softmax_cross_entropy_with_logits(
            logits=prediction, labels=y)#prediction and original comapriosn
    cost = tf.reduce_mean(softMax)#total loss
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)#learning_rate=0.01
    hm_epochs = 20

    new_saver = tf.train.Saver()
    with tf.Session() as sess:
        global s
        s=sess
        sess.run(tf.global_variables_initializer())
        new_saver = tf.train.import_meta_graph('../MY_MODELS/my_MNIST_Model_1.meta')
        new_saver.restore(s, tf.train.latest_checkpoint('../MY_MODELS'))

        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        print('Accuracy:', accuracy.eval(
            {x: mnist.test.images, y: mnist.test.labels}))

1 个答案:

答案 0 :(得分:2)

您已加载的图表已包含推理所需的所有变量。您需要从保存的图表中加载accuracy等张量。在您的情况下,您在外部声明了相同的变量,这与图中的变量冲突。

在训练期间,如果您使用accuracy命名了张量name='accuracy',则可以使用以下方法从图表中加载get_tensor_by_name('accuracy:0')。在您的示例中,您还需要从图表中加载输入张量xy。您的代码应该是:

def inference():
   loaded_graph = tf.Graph()
   new_saver = tf.train.Saver()
   with tf.Session(graph=loaded_graph) as sess:
       new_saver = tf.train.import_meta_graph('../MY_MODELS/my_MNIST_Model_1.meta')
       new_saver.restore(s, tf.train.latest_checkpoint('../MY_MODELS'))

       #Get the tensors by their variable name 
       # Note: the names of the following tensors have to be declared in your train graph for this to work. So just name them appropriately.
      _accuracy = loaded_graph.get_tensor_by_name('accuracy:0')
      _x  = loaded_graph.get_tensor_by_name('x:0')
      _y  = loaded_graph.get_tensor_by_name('y:0')

       print('Accuracy:', _accuracy.eval(
        {_x: mnist.test.images, _y: mnist.test.labels}))