你好,我的数据集迭代器突然出现问题。我已经在堆栈溢出中看到了类似的问题,但是它们都无法解决我的问题,因此我将其发布在这里。
训练后创建验证迭代器时,我的代码可以完美运行。但是现在我想看看损耗在测试集上的表现如何,因此需要训练和测试数据集。无论如何,当我尝试运行我的代码时,它始终表示我的第二个迭代器尚未初始化,我相信它已经初始化了。我几乎对所有事情都进行了尝试,使用variable_scape,重命名变量等。如果有人可以看一下我的代码并告诉我我在哪里出错了?我从文档中的tensorflow example开始,非常了解如何根据字符串句柄创建迭代器。
def run(self, model="NN", use_gazemap=False):
# Input with gazemap or without
graph_input = self.projection(use_gazemap=self.use_gazemap)
self.predictions = self.classification_graph_nn(graph_input)
handle = tf.placeholder(tf.string, shape=[])
# Valid Dataset
valid_size = 65268
self.valid_iterator, self.valid_dataset = load_data("valid",
self.batch_size, "valid.tfrecord")
#Train Dataset
train_size = 58212
self.train_iterator, self.train_dataset = load_data("train",
self.batch_size, "train.tfrecord")
# Iterator
iterator = tf.data.Iterator.from_string_handle(handle,
self.train_dataset.output_types,
self.train_dataset.output_shapes)
next_element = iterator.get_next()
valid_handle = self.session.run(self.valid_iterator.string_handle())
training_handle = self.session.run(self.train_iterator.string_handle())
self.session.run(tf.global_variables_initializer())
run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
# Summary
with tf.variable_scope('logging'):
tf.summary.scalar('current_cost', self.loss)
tf.summary.scalar('learning_rate', self.learning_rate)
summary = tf.summary.merge_all()
training_writer = tf.summary.FileWriter(
'./logs/training', self.session.graph)
testing_writer = tf.summary.FileWriter('.logs/testing', self.session.graph)
# Training model
for epoch in range(hparams.num_epochs):
self.session.run(self.train_iterator.initializer)
for it in range(train_size / hparams.batch_size):
# Training
frames, c3d, labels, gaze_gt, gaze_pred = self.session.run(
next_element, feed_dict={handle: training_handle})
feed_dict = {self.c3d: c3d,
self.gazemap: gaze_gt, self.labels: labels}
loss, _, global_step, learning_rate, training_summary = self.session.run(
[self.loss, self.train_op, self.global_step, self.learning_rate, summary], feed_dict=feed_dict, options=run_options)
# Testing
frames, c3d, labels, gaze_gt, gaze_pred = self.session.run(next_element, feed_dict={handle: valid_handle})
feed_dict = {self.c3d: c3d,
self.gazemap: gaze_gt, self.labels: labels}
test_loss, testing_summary = self.session.run(
[self.loss, summary], feed_dict=feed_dict, options=run_options)
if global_step % self.steps_per_logprint == 0:
self.session.run(self.predictions,
feed_dict=feed_dict, options=run_options)
batch_score = self.evaluate(self.predictions.eval(
feed_dict=feed_dict), self.labels.eval(feed_dict=feed_dict))
答案 0 :(得分:0)
您必须像初始化训练迭代器一样初始化验证迭代器。 为此,请在纪元的开头添加以下行:
self.session.run(self.valid_iterator.initializer)