所以我试图加载我训练有素的Tensorflow模型,但得到这个奇怪的错误,我无法找到有关此特定错误的任何答案。
这是我的救星电话:
curl -X POST -d '¶m1=1¶m2=2...' URL
这是我的恢复功能:
for inner_dict in dict1['c']:
for k, v in inner_dict:
do_something()
目标是使用已恢复的模型进行推理,但我在" saver = tf.train.import_meta_graph(trained_model_name)"中收到错误。一些帮助会很棒:)
错误代码:
with tf.Session(graph=self.graph) as sess:
saver = tf.train.Saver()
for i in range(self.c.epochs):
batch_data, batch_labels = self.get_batch(train_keys, self.c.doc_len, self.c.num_classes, batch_size=self.c.batch_size)
_, batch_loss = sess.run([self.optimizer, self.loss], feed_dict={self.input_data: batch_data, self.labels: batch_labels, self.dropout_rate: 0.5})
if (i % 2 == 0 and i != 0 or i == self.c.epochs-1):
saver.save(sess, save_model_file, global_step=2)
答案 0 :(得分:0)
在使用LSTMBlockFusedCell()时,我遇到了同样的问题。 解决方案在https://github.com/tensorflow/tensorflow/issues/23369
中# for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369
tf.contrib.rnn
# restore meta graph
meta_file = args.restore + '.meta'
loader = tf.train.import_meta_graph(meta_file, clear_devices=True)
...