Restoring a saved model and evaluating it on a new TensorFlow Data object

Date: 2019-03-18 19:21:49

Tags: python tensorflow lstm tensorflow-datasets

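I have a saved model that I want to restore. After restoring it, I want to evaluate it on a new dataset, with the input fed through a TensorFlow Data (tf.data) input pipeline.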

import tensorflow as tf
from tfwrappers.tf_dataset import Dataset


tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_deep_model_2017.ckpt.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    print("Restored Operations from MetaGraph:")
    g = tf.get_default_graph()

    batch_size = 128
    num_steps = 4

    # Builds a new input pipeline (with its own iterator) inside the restored graph.
    train_init_op, test_init_op, Xtest, ytest = Dataset(year=2017, batch_size=batch_size).build_iterator()

    accuracy_update_op = g.get_tensor_by_name('LSTM/Accuracy/accuracy/update_op:0')
    accuracy = g.get_tensor_by_name('LSTM/Accuracy/accuracy/value:0')

    auc_update_op = g.get_tensor_by_name('LSTM/AUC/auc/update_op:0')
    auc = g.get_tensor_by_name('LSTM/AUC/auc/value:0')

    total_test_batch = int((400000/(num_steps * batch_size))+1)
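    # Only the metric counters (local variables) need initializing here; running
    # tf.global_variables_initializer() after restore() would overwrite the restored weights.
    tf.local_variables_initializer().run()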
    sess.run(test_init_op)
    for _ in range(total_test_batch):
        sess.run([auc_update_op, accuracy_update_op])
    accuracy_test= sess.run(accuracy)
    AUC_test = sess.run(auc)
    print("Test accuracy: {:>.2%}".format(accuracy_test), "Test AUC: {:>.2%}".format(AUC_test))
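For reference, the tf.metrics tensors looked up above work as a pair: each run of an update_op adds the current batch to running totals kept in local variables, and the value tensor only reads those totals. A minimal, self-contained sketch with toy placeholders and numbers (not the model above), assuming the tf.compat.v1 API in graph mode (TF 1.14+ or 2.x):

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x


def streaming_accuracy_demo():
    """Show how a tf.metrics update_op/value pair accumulates and resets."""
    graph = tf.Graph()
    with graph.as_default():
        labels = tf1.placeholder(tf.int32, [None])
        preds = tf1.placeholder(tf.int32, [None])
        # accuracy reads the totals; accuracy_update adds one batch to them
        accuracy, accuracy_update = tf1.metrics.accuracy(labels=labels, predictions=preds)

    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.local_variables_initializer())  # sets total/count to 0
        sess.run(accuracy_update, {labels: [1, 0], preds: [1, 1]})  # 1 of 2 correct
        sess.run(accuracy_update, {labels: [1, 1], preds: [1, 1]})  # 2 of 2 correct
        after_two_batches = float(sess.run(accuracy))  # 3 of 4 correct
        sess.run(tf1.local_variables_initializer())  # resets total/count
        after_reset = float(sess.run(accuracy))  # empty totals read as 0.0
    return after_two_batches, after_reset


if __name__ == "__main__":
    print(streaming_accuracy_demo())  # (0.75, 0.0)
```

This is why tf.local_variables_initializer() has to run before the evaluation loop: the metric totals are local variables and are not stored in the checkpoint.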

The error I get is FailedPreconditionError: GetNext() failed because the iterator has not been initialized. However, I already initialize it with sess.run(test_init_op).
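A plausible cause (not verified against the actual saved graph): build_iterator() is called after import_meta_graph(), so it adds a second, brand-new iterator to the restored graph. test_init_op initializes that new iterator, but the restored LSTM/Accuracy and LSTM/AUC ops were built on the get_next() output of the original iterator stored in the meta graph, which is never initialized. The toy graph below, with hypothetical datasets standing in for both pipelines, reproduces the same FailedPreconditionError (tf.compat.v1 API assumed):

```python
import tensorflow as tf

tf1 = tf.compat.v1


def demo_two_iterators():
    """Initialize one iterator, then read through a different one."""
    graph = tf.Graph()
    with graph.as_default():
        # Stand-in for the iterator saved with the model (hypothetical data).
        old_ds = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
        old_it = tf1.data.Iterator.from_structure(
            tf1.data.get_output_types(old_ds), tf1.data.get_output_shapes(old_ds))
        model_output = old_it.get_next() * 10.0  # "model" wired to the old iterator

        # Stand-in for calling build_iterator() again after the restore.
        new_ds = tf.data.Dataset.from_tensor_slices([7.0, 8.0])
        new_it = tf1.data.Iterator.from_structure(
            tf1.data.get_output_types(new_ds), tf1.data.get_output_shapes(new_ds))
        new_init_op = new_it.make_initializer(new_ds)
        new_next = new_it.get_next()

    with tf1.Session(graph=graph) as sess:
        sess.run(new_init_op)  # initializes ONLY new_it
        value = float(sess.run(new_next))  # the new iterator works
        try:
            sess.run(model_output)  # old_it was never initialized
            error_name = None
        except tf.errors.FailedPreconditionError as err:
            error_name = type(err).__name__
    return value, error_name


if __name__ == "__main__":
    print(demo_two_iterators())  # (7.0, 'FailedPreconditionError')
```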

The Dataset module is very basic: a Python generator reads data points from an SQL database, and a Dataset object is created from it.

def build_iterator(self):
    with tf.name_scope("Data"):
        train_generator = PairGenerator(sql = '*SQL QUERY 1*'.format(self.year), max_rows=1600400)
        validation_generator = PairGenerator(sql = '*SQL QUERY 2*'.format(self.year), max_rows=400000)
        train_dataset = tf.data.Dataset.from_generator(
            lambda: train_generator,
            (tf.float32, tf.int32),
            (tf.TensorShape([self.num_steps, self.num_inputs]), tf.TensorShape([self.num_steps])))
        train_dataset = train_dataset.apply(tf.contrib.data.map_and_batch(
            map_func=lambda *x: (x[0], tf.cast(tf.one_hot(x[1], self.num_classes), tf.int32)),
            batch_size=self.batch_size,
            num_parallel_calls=self.num_parallel_calls,
            drop_remainder=False)).prefetch(self.prefetch_batch_buffer).repeat(self.num_epochs)

        validation_dataset = tf.data.Dataset.from_generator(
            lambda: validation_generator,
            (tf.float32, tf.int32),
            (tf.TensorShape([self.num_steps, self.num_inputs]), tf.TensorShape([self.num_steps])))
        validation_dataset = validation_dataset.apply(tf.contrib.data.map_and_batch(
            map_func=lambda *x: (x[0], tf.cast(tf.one_hot(x[1], self.num_classes), tf.int32)),
            batch_size=self.batch_size,
            num_parallel_calls=self.num_parallel_calls,
            drop_remainder=False)).prefetch(self.prefetch_batch_buffer).repeat(self.num_epochs)

        iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
        training_init_op = iterator.make_initializer(train_dataset, name='training_init_op')
        validation_init_op = iterator.make_initializer(validation_dataset, name='validation_init_op')

        X, y = iterator.get_next(name = 'get_next_datapoint')
    return training_init_op, validation_init_op, X, y
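For context, build_iterator() uses a reinitializable iterator: a single iterator built with from_structure(), plus one make_initializer() op per dataset that points that same iterator at a different source. A reduced sketch of the pattern, with toy range datasets in place of the SQL generators (tf.compat.v1 API assumed):

```python
import tensorflow as tf

tf1 = tf.compat.v1


def switch_datasets():
    """One reinitializable iterator, two datasets, as in build_iterator()."""
    graph = tf.Graph()
    with graph.as_default(), tf1.name_scope("Data"):
        # Toy stand-ins for the training and validation generators.
        train_ds = tf.data.Dataset.range(0, 5).batch(2)
        valid_ds = tf.data.Dataset.range(100, 103).batch(2)
        # The iterator is defined by element types/shapes, not by a dataset.
        iterator = tf1.data.Iterator.from_structure(
            tf1.data.get_output_types(train_ds), tf1.data.get_output_shapes(train_ds))
        # Each init op points *this* iterator at one dataset.
        training_init_op = iterator.make_initializer(train_ds, name="training_init_op")
        validation_init_op = iterator.make_initializer(valid_ds, name="validation_init_op")
        x = iterator.get_next(name="get_next_datapoint")

    with tf1.Session(graph=graph) as sess:
        sess.run(training_init_op)
        from_train = sess.run(x).tolist()
        sess.run(validation_init_op)  # same x tensor, now fed by valid_ds
        from_valid = sess.run(x).tolist()
    return from_train, from_valid


if __name__ == "__main__":
    print(switch_datasets())  # ([0, 1], [100, 101])
```

Each init op is bound to the specific iterator object it was created from, so an init op produced by a newly built pipeline only re-points that new iterator, not one that already lives in a restored graph.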

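Most of the solutions I found are about restoring the iterator and feeding it a new dataset. I haven't been able to come up with a solution.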

Edit: I forgot to mention that this saved model was trained with a different tf.data Dataset object.
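Since the model was trained with a different Dataset object, one approach that avoids touching the old iterator at all is to build the new pipeline first and then pass input_map to tf.train.import_meta_graph() (which forwards it to the underlying graph import), so the restored ops read from the new get_next() tensors instead of the old ones. This is a toy sketch under stated assumptions: save_toy_model() is a hypothetical stand-in for the original training script, and the tensor names are read from the toy graph rather than known for the real model. For the real checkpoint, the original get_next op would have to be located in the restored graph, for example with [op.name for op in g.get_operations() if op.type == "IteratorGetNext"].

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def save_toy_model(ckpt_dir):
    """Hypothetical stand-in for the original training script."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.name_scope("Data"):  # the "other" Dataset object used in training
            ds = tf.data.Dataset.from_tensor_slices(([[1.0], [2.0]], [1, 0])).batch(2)
            it = tf1.data.Iterator.from_structure(
                tf1.data.get_output_types(ds), tf1.data.get_output_shapes(ds))
            it.make_initializer(ds, name="training_init_op")
            x, y = it.get_next(name="get_next_datapoint")
        w = tf1.get_variable("w", initializer=tf.constant([[3.0]]))
        preds = tf.reshape(tf.cast(tf.matmul(x, w) > 4.0, tf.int32), [-1])
        accuracy, accuracy_update = tf1.metrics.accuracy(labels=y, predictions=preds)
        saver = tf1.train.Saver()  # saves global variables (w), not metric counters
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            sess.run(w.assign([[5.0]]))  # pretend training moved w from 3.0 to 5.0
            path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))
    names = {"x": x.name, "y": y.name,
             "accuracy": accuracy.name, "accuracy_update": accuracy_update.name}
    return path, names


def evaluate_on_new_data(ckpt_path, names):
    graph = tf.Graph()
    with graph.as_default():
        # 1. Build the NEW pipeline first, in its own scope to avoid name clashes.
        with tf1.name_scope("NewData"):
            new_ds = tf.data.Dataset.from_tensor_slices(([[1.0], [0.5]], [1, 0])).batch(1)
            new_it = tf1.data.make_initializable_iterator(new_ds)
            new_x, new_y = new_it.get_next()
        # 2. Import the saved graph, replacing the old get_next outputs with the new ones.
        saver = tf1.train.import_meta_graph(
            ckpt_path + ".meta", input_map={names["x"]: new_x, names["y"]: new_y})
        accuracy = graph.get_tensor_by_name(names["accuracy"])
        accuracy_update = graph.get_tensor_by_name(names["accuracy_update"])

    with tf1.Session(graph=graph) as sess:
        saver.restore(sess, ckpt_path)  # restores w = 5.0
        sess.run(tf1.local_variables_initializer())  # metric counters only
        sess.run(new_it.initializer)
        while True:  # run until the new dataset is exhausted
            try:
                sess.run(accuracy_update)
            except tf.errors.OutOfRangeError:
                break
        return float(sess.run(accuracy))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt, tensor_names = save_toy_model(tmp)
        print(evaluate_on_new_data(ckpt, tensor_names))  # 1.0
```

The result of 1.0 in this toy setup depends on both the restored weight (w = 5.0; the initial 3.0 would give 0.5) and the remapped input (the old training data would also give 0.5). The loop runs the update op until tf.errors.OutOfRangeError instead of relying on a precomputed batch count.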

0 answers:

No answers yet