Tensorflow从字符串句柄创建第二个迭代器-GetNext()失败,因为未初始化

时间:2018-07-07 06:49:51

标签: python tensorflow

你好,我的数据集迭代器突然出现问题。我已经在堆栈溢出中看到了类似的问题,但是它们都无法解决我的问题,因此我将其发布在这里。

训练后创建验证迭代器时,我的代码可以完美运行。但是现在我想看看损耗在测试集上的表现如何,因此需要训练和测试数据集。无论如何,当我尝试运行我的代码时,它始终表示我的第二个迭代器尚未初始化,我相信它已经初始化了。我几乎对所有事情都进行了尝试,使用variable_scape,重命名变量等。如果有人可以看一下我的代码并告诉我我在哪里出错了?我从文档中的tensorflow example开始,非常了解如何根据字符串句柄创建迭代器。

 def run(self, model="NN", use_gazemap=False):
    # Input with gazemap or without

    graph_input = self.projection(use_gazemap=self.use_gazemap)
    self.predictions = self.classification_graph_nn(graph_input)

    handle = tf.placeholder(tf.string, shape=[])
    # Valid Dataset 
    valid_size = 65268
    self.valid_iterator, self.valid_dataset = load_data("valid",
                                                        self.batch_size, "valid.tfrecord")


    #Train Dataset 
    train_size = 58212 
    self.train_iterator, self.train_dataset = load_data("train",
                                                        self.batch_size, "train.tfrecord")


    # Iterator 
    iterator = tf.data.Iterator.from_string_handle(handle,
                self.train_dataset.output_types,
                self.train_dataset.output_shapes)

    next_element = iterator.get_next()


    valid_handle = self.session.run(self.valid_iterator.string_handle())
    training_handle = self.session.run(self.train_iterator.string_handle())
    self.session.run(tf.global_variables_initializer())


    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)

    # Summary
    with tf.variable_scope('logging'):
        tf.summary.scalar('current_cost', self.loss)
        tf.summary.scalar('learning_rate', self.learning_rate)
        summary = tf.summary.merge_all()

    training_writer = tf.summary.FileWriter(
        './logs/training', self.session.graph)
    testing_writer = tf.summary.FileWriter('.logs/testing', self.session.graph)


    # Training model 
    for epoch in range(hparams.num_epochs):
        self.session.run(self.train_iterator.initializer)
        for it in range(train_size / hparams.batch_size):

            # Training
            frames, c3d, labels, gaze_gt, gaze_pred = self.session.run(
                    next_element, feed_dict={handle: training_handle})
            feed_dict = {self.c3d: c3d,
                             self.gazemap: gaze_gt, self.labels: labels}
            loss, _, global_step, learning_rate, training_summary = self.session.run(
                    [self.loss, self.train_op, self.global_step, self.learning_rate, summary], feed_dict=feed_dict, options=run_options)

            # Testing 
            frames, c3d, labels, gaze_gt, gaze_pred = self.session.run(next_element, feed_dict={handle: valid_handle})
            feed_dict = {self.c3d: c3d,
                                 self.gazemap: gaze_gt, self.labels: labels}
            test_loss, testing_summary = self.session.run(
                        [self.loss,  summary], feed_dict=feed_dict, options=run_options)

            if global_step % self.steps_per_logprint == 0:
                self.session.run(self.predictions,
                                 feed_dict=feed_dict, options=run_options)
                batch_score = self.evaluate(self.predictions.eval(
                    feed_dict=feed_dict), self.labels.eval(feed_dict=feed_dict))

1 个答案:

答案 0 :(得分:0)

您必须像初始化训练迭代器一样初始化验证迭代器。 为此,请在纪元的开头添加以下行:

self.session.run(self.valid_iterator.initializer)