Not found: Key Variable_<x> not found in checkpoint

Date: 2018-07-30 06:29:16

Tags: python-3.x tensorflow

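I am trying to save a trained model and use it later in another instance (function). However, this somehow throws a variable-not-found error. After going through SO and other forums, I understand that the problem lies in the way I am saving the model.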

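    import random

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib import rnn  # TF 1.x

    # build_dataset() and training_data come from elsewhere in the original program
    dictionary, reverse_dictionary = build_dataset(training_data)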

    vocab_size = len(dictionary)
    n_input = 3
    n_hidden = 512

    # RNN output node weights and biases
    weights = {'out': tf.Variable(tf.random_normal([n_hidden, vocab_size]))}
    biases = {'out': tf.Variable(tf.random_normal([vocab_size]))}

    # tf Graph input
    x = tf.placeholder("float", [None, n_input, 1])
    y = tf.placeholder("float", [None, vocab_size])

    # RNN implementation in Tensorflow
    def RNN(x,weights,biases):       
        x = tf.reshape(x, [-1, n_input])       
        x = tf.split(x, n_input, 1)      
        rnn_cell = rnn.BasicLSTMCell(n_hidden)
        outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32)
        return tf.matmul(outputs[-1], weights['out']) + biases['out']

    pred = RNN(x, weights, biases)

    learning_rate = 0.001 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(cost)
    correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Initializing the variables
    init = tf.global_variables_initializer()

    training_iters = 1000
    display_step = 500

    saver = tf.train.Saver()

    # Launch the graph
    with tf.Session() as session:
        session.run(init)
        step = 0
        offset = random.randint(0, n_input+1)
        end_offset = n_input + 1
        acc_total = 0
        loss_total = 0

        while step < training_iters:

            if offset > (len(training_data)-end_offset):
                offset = random.randint(0, n_input+1)

            symbols_in_keys = [ [dictionary[ str(training_data[i])]] for i in range(offset, offset+n_input) ]
            symbols_in_keys = np.reshape(np.array(symbols_in_keys), [-1, n_input, 1])

            symbols_out_onehot = np.zeros([vocab_size], dtype=float)
            symbols_out_onehot[dictionary[str(training_data[offset+n_input])]] = 1.0
            symbols_out_onehot = np.reshape(symbols_out_onehot, [1, -1])

            _, acc, loss, onehot_pred = session.run([optimizer, accuracy, cost, pred], \
                                                    feed_dict={x: symbols_in_keys, y: symbols_out_onehot})
            loss_total += loss
            acc_total += acc
            if (step+1) % display_step == 0:
                print("Iter= " + str(step+1) + ", Average Loss= " + \
                      "{:.6f}".format(loss_total/display_step) + ", Average Accuracy= " + \
                      "{:.2f}%".format(100*acc_total/display_step))
                acc_total = 0
                loss_total = 0
                symbols_in = [training_data[i] for i in range(offset, offset + n_input)]
                symbols_out = training_data[offset + n_input]
                symbols_out_pred = reverse_dictionary[int(tf.argmax(onehot_pred, 1).eval())]
                print("%s - [%s] vs [%s]" % (symbols_in,symbols_out,symbols_out_pred))
            step += 1
            offset += (n_input+1)

        saver.save(session, 'userLocation/Model')

The model files are generated, but when I try to restore the model with

    saver = tf.train.Saver()
    with tf.Session() as restored_session:
        saver.restore(restored_session, 'userLocation/Model')

I get this error:

    tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_3 not found in checkpoint
         [[Node: save_1/RestoreV2_7 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/RestoreV2_7/tensor_names, save_1/RestoreV2_7/shape_and_slices)]]

Any hints on what I am missing when saving?

1 answer:

Answer 0 (score: 0)

I will explain this in two parts:

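  1. When you save a model in TensorFlow, the graph is written to one file (the one with the .meta extension) and the values of the variable tensors to separate checkpoint files (an .index file plus one or more .data-* files).

  2. When restoring, you have to follow the same two steps: a) first import the graph, then b) create a session and restore the variables into it.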
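To make the file split in point 1 concrete, here is a small pure-Python sketch (no TensorFlow needed). It assumes the default V2 checkpoint format of a TF 1.x tf.train.Saver with a single shard, so that saver.save(session, 'userLocation/Model') leaves Model.meta, Model.index and Model.data-00000-of-00001 in userLocation/ (next to a small text file named checkpoint). The helper names prefix_from_meta and checkpoint_files are hypothetical. The point is that the path handed to restore() is the prefix without .meta, which is what both restore calls in this thread pass.

```python
from typing import List

META_SUFFIX = ".meta"


def prefix_from_meta(meta_path: str) -> str:
    """Return the checkpoint prefix (the path Saver.restore() expects) for a .meta file."""
    if not meta_path.endswith(META_SUFFIX):
        raise ValueError(f"expected a path ending in {META_SUFFIX!r}, got {meta_path!r}")
    return meta_path[: -len(META_SUFFIX)]


def checkpoint_files(prefix: str, num_shards: int = 1) -> List[str]:
    """List the files a TF1 Saver typically writes for `prefix` (assumed V2 format).

    <prefix>.meta           -> the graph (MetaGraphDef), read by import_meta_graph()
    <prefix>.index          -> index mapping variable names to their stored data
    <prefix>.data-NNNNN-... -> the variable values themselves, one file per shard
    """
    files = [prefix + META_SUFFIX, prefix + ".index"]
    files += [f"{prefix}.data-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    return files


if __name__ == "__main__":
    prefix = prefix_from_meta("userLocation/Model.meta")
    print(prefix)                    # userLocation/Model
    print(checkpoint_files(prefix))
    # ['userLocation/Model.meta', 'userLocation/Model.index',
    #  'userLocation/Model.data-00000-of-00001']
```

If the .index or .data file is missing, importing the .meta graph can still succeed while the variable restore step fails.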

Here is sample code:

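    import tensorflow as tf

    # Start from an empty default graph so the imported graph does not mix
    # with variables that were built earlier in this process
    tf.reset_default_graph()
    tf.set_random_seed(10)

    # Location of the graph definition written by saver.save()
    meta_file = 'userLocation/Model.meta'

    # a) Import the graph; import_meta_graph() returns a Saver for its variables
    ns = tf.train.import_meta_graph(meta_file, clear_devices=True)

    # b) Create a session and restore the variable values into it
    with tf.Session() as sess:
        # restore() expects the checkpoint prefix, i.e. the path without '.meta'
        ns.restore(sess, meta_file[:-len('.meta')])

        # For example, if the graph has a tensor named 'x'. The question's
        # placeholders were created without name=, so they are actually called
        # 'Placeholder:0' and 'Placeholder_1:0' unless name='x' is passed
        x = tf.get_default_graph().get_tensor_by_name("x:0")
        # ...
        # Further processing/prediction etc.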
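A likely reason for the Variable_3 key in the question (an inference from the save_1/... node name, which TF1 gives to a second Saver in the same graph; the traceback does not say this directly): tf.Variable objects created without a name= argument are auto-named Variable, Variable_1, ..., and the counter is per graph. If the model-building code runs a second time in the same process and default graph, weights['out'] and biases['out'] become Variable_2 and Variable_3, and the new tf.train.Saver() asks the checkpoint for keys it never contained. The sketch below is a simplified pure-Python model of that naming rule and of name-based restore; UniqueNamer, build_model, missing_keys and restore are hypothetical stand-ins, not TensorFlow APIs, and only the two unnamed variables are modelled (the real graph also holds LSTM and RMSProp slot variables).

```python
from typing import Dict, List


class UniqueNamer:
    """Simplified, hypothetical model of TF1's per-graph unique naming.

    The first request for a base name returns it unchanged; later requests
    get '_1', '_2', ... suffixes, skipping names that are already taken.
    """

    def __init__(self) -> None:
        self._in_use = set()
        self._counters: Dict[str, int] = {}

    def unique_name(self, base: str) -> str:
        i = self._counters.get(base, 0)
        name = base if i == 0 else f"{base}_{i}"
        while name in self._in_use:
            i += 1
            name = f"{base}_{i}"
        self._counters[base] = i + 1
        self._in_use.add(name)
        return name


def build_model(namer: UniqueNamer) -> List[str]:
    """Mimic the question's two unnamed tf.Variable calls (weights, then biases)."""
    return [namer.unique_name("Variable"), namer.unique_name("Variable")]


def missing_keys(checkpoint: Dict[str, float], wanted: List[str]) -> List[str]:
    """Variable names a name-based restore would ask for but not find."""
    return [name for name in wanted if name not in checkpoint]


def restore(checkpoint: Dict[str, float], wanted: List[str]) -> Dict[str, float]:
    """Name-based restore: fail if any requested key is absent from the checkpoint."""
    missing = missing_keys(checkpoint, wanted)
    if missing:
        raise LookupError(f"Keys not found in checkpoint: {missing}")
    return {name: checkpoint[name] for name in wanted}


if __name__ == "__main__":
    namer = UniqueNamer()                         # one default graph for the whole process
    trained = build_model(namer)                  # ['Variable', 'Variable_1']
    checkpoint = {name: 0.0 for name in trained}  # what saver.save() wrote (dummy values)

    rebuilt = build_model(namer)                  # same graph again: ['Variable_2', 'Variable_3']
    try:
        restore(checkpoint, trained + rebuilt)    # a second Saver covers all four variables
    except LookupError as err:
        print(err)  # Keys not found in checkpoint: ['Variable_2', 'Variable_3']

    fresh = build_model(UniqueNamer())            # fresh graph, like after tf.reset_default_graph()
    print(sorted(restore(checkpoint, fresh)))     # ['Variable', 'Variable_1']
```

Under this reading, importing the graph into a fresh graph (or calling tf.reset_default_graph() before rebuilding, or rebuilding in a new process) makes the names line up again; giving the variables explicit name= arguments additionally makes the checkpoint keys independent of creation order.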