我正在使用深度学习来构建打字助手。我已经有一个预先训练好的模型,我试图加载它以预测下几个单词。
虽然代码可以在服务器上运行(模型已经过训练),但是当我尝试在系统上加载模型并尝试预测时。它产生了这个错误。
tensorflow / core / framework / op_kernel.cc:1152]未找到:密钥dq4st0 / multi_rnn_cell / cell_0 / basic_lstm_cell /在检查点中找不到偏差
预测代码如下
def text_output(args, bucket):
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args['gpu_mem'])
with open(os.path.join(args['save_dir'], str(bucket)+'/config.pkl'), 'rb') as f:
saved_args = cPickle.load(f)
with open(os.path.join(args['save_dir'], str(bucket)+'/words_vocab.pkl'), 'rb') as f:
words, vocab = cPickle.load(f)
model = Model(saved_args, bucket, True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
tf.global_variables_initializer().run(session =sess)
saver = tf.train.Saver(tf.global_variables())
ckpt = tf.train.get_checkpoint_state(args['save_dir']+"/"+str(bucket))
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
return args,model,words,vocab, sess
答案 0 :(得分:4)
可能的问题是代码中的变量名称与检查点文件中的键不匹配。
我的建议是检查变量名称,如下所示:
在代码中获取变量名称:
var_name_list = [v.name for v in tf.trainable_variables()]
获取检查点文件中的密钥:
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
var_to_shape_map = reader.get_variable_to_shape_map()
您可以对它们进行比较,以检查ckpt文件中是否存在dq4st0/multi_rnn_cell/cell_0/basic_lstm_cell/biases
。