在检查点Tensorflow中找不到密钥

时间:2017-08-22 12:42:18

标签: python tensorflow

我正在使用深度学习来构建打字助手。我已经有一个预先训练好的模型,我试图加载它以预测下几个单词。

虽然代码可以在服务器上运行(模型已经过训练),但是当我尝试在系统上加载模型并尝试预测时。它产生了这个错误。

  

tensorflow / core / framework / op_kernel.cc:1152]未找到:密钥dq4st0 / multi_rnn_cell / cell_0 / basic_lstm_cell /在检查点中找不到偏差

预测代码如下

def text_output(args, bucket):
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args['gpu_mem'])

    with open(os.path.join(args['save_dir'], str(bucket)+'/config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args['save_dir'], str(bucket)+'/words_vocab.pkl'), 'rb') as f:
        words, vocab = cPickle.load(f)
    model = Model(saved_args, bucket, True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    tf.global_variables_initializer().run(session =sess)
    saver = tf.train.Saver(tf.global_variables())

    ckpt = tf.train.get_checkpoint_state(args['save_dir']+"/"+str(bucket))

    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

    return args,model,words,vocab, sess

1 个答案:

答案 0 :(得分:4)

可能的问题是代码中的变量名称与检查点文件中的键不匹配。

我的建议是检查变量名称,如下所示:

  1. 在代码中获取变量名称:

    var_name_list = [v.name for v in tf.trainable_variables()]
    
  2. 获取检查点文件中的密钥:

    from tensorflow.python import pywrap_tensorflow
    
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    
  3. 您可以对它们进行比较,以检查ckpt文件中是否存在dq4st0/multi_rnn_cell/cell_0/basic_lstm_cell/biases