I have this saved model and I want to restore it. After restoring it, I want to evaluate it on a new dataset, feeding the input through a TensorFlow tf.data input pipeline.
import tensorflow as tf
from tfwrappers.tf_dataset import Dataset

tf.reset_default_graph()

with tf.Session() as sess:
    # Restore the graph structure and the latest checkpoint weights
    new_saver = tf.train.import_meta_graph('my_deep_model_2017.ckpt.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    print("Restored Operations from MetaGraph:")
    g = tf.get_default_graph()

    batch_size = 128
    num_steps = 4

    # Build a fresh input pipeline for the new evaluation data
    train_init_op, test_init_op, Xtest, ytest = Dataset(year=2017, batch_size=batch_size).build_iterator()

    # Streaming metric ops fetched from the restored graph
    accuracy_update_op = g.get_tensor_by_name('LSTM/Accuracy/accuracy/update_op:0')
    accuracy = g.get_tensor_by_name('LSTM/Accuracy/accuracy/value:0')
    auc_update_op = g.get_tensor_by_name('LSTM/AUC/auc/update_op:0')
    auc = g.get_tensor_by_name('LSTM/AUC/auc/value:0')

    total_test_batch = int((400000 / (num_steps * batch_size)) + 1)

    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    sess.run(test_init_op)

    for _ in range(total_test_batch):
        sess.run([auc_update_op, accuracy_update_op])

    accuracy_test = sess.run(accuracy)
    AUC_test = sess.run(auc)
    print("Test accuracy: {:>.2%}".format(accuracy_test), "Test AUC: {:>.2%}".format(AUC_test))
The error I get is FailedPreconditionError: GetNext() failed because the iterator has not been initialized. However, I do initialize the iterator with sess.run(test_init_op).
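To narrow this down, one check I can run is to list the iterator-related operations that actually exist in the restored graph (a small debugging sketch; it only assumes the checkpoint files named above, nothing model-specific):

# Debugging sketch: print the iterator-related ops present in the restored graph.
import tensorflow as tf

tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my_deep_model_2017.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    g = tf.get_default_graph()
    for op in g.get_operations():
        # Show iterator creation, initializer and get-next ops
        if 'Iterator' in op.type or 'init_op' in op.name or 'get_next' in op.name:
            print(op.name, op.type)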
The Dataset module is quite basic: a Python generator reads data points from a SQL database and the tf.data Dataset objects are built from it.
def build_iterator(self):
    with tf.name_scope("Data"):
        train_generator = PairGenerator(sql='*SQL QUERY 1*'.format(self.year), max_rows=1600400)
        validation_generator = PairGenerator(sql='*SQL QUERY 2*'.format(self.year), max_rows=400000)
        train_dataset = tf.data.Dataset.from_generator(
            lambda: train_generator, (tf.float32, tf.int32),
            (tf.TensorShape([self.num_steps, self.num_inputs]), tf.TensorShape([self.num_steps])))
        train_dataset = train_dataset.apply(tf.contrib.data.map_and_batch(
            map_func=lambda *x: (x[0], tf.cast(tf.one_hot(x[1], self.num_classes), tf.int32)),
            batch_size=self.batch_size, num_parallel_calls=self.num_parallel_calls, drop_remainder=False)
        ).prefetch(self.prefetch_batch_buffer).repeat(self.num_epochs)
        validation_dataset = tf.data.Dataset.from_generator(
            lambda: validation_generator, (tf.float32, tf.int32),
            (tf.TensorShape([self.num_steps, self.num_inputs]), tf.TensorShape([self.num_steps])))
        validation_dataset = validation_dataset.apply(tf.contrib.data.map_and_batch(
            map_func=lambda *x: (x[0], tf.cast(tf.one_hot(x[1], self.num_classes), tf.int32)),
            batch_size=self.batch_size, num_parallel_calls=self.num_parallel_calls, drop_remainder=False)
        ).prefetch(self.prefetch_batch_buffer).repeat(self.num_epochs)
        # One reinitializable iterator shared by the training and validation datasets
        iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
        training_init_op = iterator.make_initializer(train_dataset, name='training_init_op')
        validation_init_op = iterator.make_initializer(validation_dataset, name='validation_init_op')
        X, y = iterator.get_next(name='get_next_datapoint')
        return training_init_op, validation_init_op, X, y
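For context, this is roughly how the shared reinitializable iterator is meant to be driven: there is one iterator, and the two init ops just switch which dataset it reads from (a simplified sketch; the model graph built on top of X and y is left out):

# Simplified usage sketch of the reinitializable iterator (batch handling only).
import tensorflow as tf
from tfwrappers.tf_dataset import Dataset

train_init_op, val_init_op, X, y = Dataset(year=2017, batch_size=128).build_iterator()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(train_init_op)       # point the shared iterator at the training dataset
    Xb, yb = sess.run([X, y])     # each run pulls one batch via get_next_datapoint
    sess.run(val_init_op)         # re-initialize the same iterator with the validation dataset
    Xb, yb = sess.run([X, y])     # now batches come from the validation generator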
Most of the solutions I have found are about restoring the iterator itself and feeding it a new dataset. I have not been able to come up with a working solution for my case.
Edit: I forgot to mention that this saved model was trained with a different tf.data Dataset object.
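For reference, the answers I found describe roughly the pattern below, where the initializer and get_next ops are fetched from the restored graph by name instead of building a new iterator. This is only a rough sketch: the names 'Data/validation_init_op' and 'Data/get_next_datapoint' are the ones given in build_iterator above, and I am not sure they match what is actually stored in the meta graph, since the model was trained with a different Dataset object.

# Sketch of the "reuse the saved iterator ops" approach from similar answers.
# Assumption: the training graph used the same op names as build_iterator above.
import tensorflow as tf

tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my_deep_model_2017.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    g = tf.get_default_graph()
    restored_init_op = g.get_operation_by_name('Data/validation_init_op')
    X = g.get_tensor_by_name('Data/get_next_datapoint:0')
    y = g.get_tensor_by_name('Data/get_next_datapoint:1')
    sess.run(tf.local_variables_initializer())   # reset the streaming metric variables
    sess.run(restored_init_op)                   # initialize the iterator the restored graph reads from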