恢复训练的Tensorflow模型KeyError:' BlockLSTM'

时间:2018-05-11 18:15:43

标签: python tensorflow lstm

所以我试图加载我训练有素的Tensorflow模型,但得到这个奇怪的错误,我无法找到有关此特定错误的任何答案。

这是我的救星电话:

curl -X POST -d '&param1=1&param2=2...' URL

这是我的恢复功能:

for inner_dict in dict1['c']: 
    for k, v in inner_dict:
        do_something()

目标是使用已恢复的模型进行推理,但我在" saver = tf.train.import_meta_graph(trained_model_name)"中收到错误。一些帮助会很棒:)

错误代码:

with tf.Session(graph=self.graph) as sess:
    saver = tf.train.Saver()
    for i in range(self.c.epochs):
        batch_data, batch_labels = self.get_batch(train_keys, self.c.doc_len, self.c.num_classes, batch_size=self.c.batch_size)

        _, batch_loss = sess.run([self.optimizer, self.loss], feed_dict={self.input_data: batch_data, self.labels: batch_labels, self.dropout_rate: 0.5})

        if (i % 2 == 0 and i != 0 or i == self.c.epochs-1):
            saver.save(sess, save_model_file, global_step=2)

1 个答案:

答案 0 :(得分:0)

在使用LSTMBlockFusedCell()时,我遇到了同样的问题。 解决方案在https://github.com/tensorflow/tensorflow/issues/23369

# for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369
tf.contrib.rnn
# restore meta graph
meta_file = args.restore + '.meta'
loader = tf.train.import_meta_graph(meta_file, clear_devices=True)
...