在使用Tensorflow训练CNN时如何解决'OutOfRangeError:序列结束'错误?

时间:2018-12-26 09:44:56

标签: python-3.x tensorflow tensorflow-datasets

我正在尝试使用自己的数据集训练CNN。我一直在使用tfrecord文件和tf.data.TFRecordDataset API处理我的数据集。它对我的训练数据集工作正常。但是,当我尝试批处理验证数据集时,出现了“ OutOfRangeError:序列结束”错误。通过Internet浏览后,我认为问题是由验证集的批处理大小引起的,我首先将其设置为32。但是在将其更改为2之后,代码运行了9个纪元,错误再次出现。

我使用输入函数来处理数据集,代码如下:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()

    features, labels = iterator.get_next()

    return features, labels

,对于训练集,“ batch_size”设置为128,“ num_epochs”设置为“无”,这意味着可以无限重复。对于验证集,“ batch_size”设置为32(后来设置为2,仍然无法使用),“ num_epochs”设置为1,因为我只想一次通过验证集。 我可以确保验证集包含足够的数据。因为我尝试了下面的代码,但没有引发任何错误:

with tf.Session() as sess:
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    for i in range(450):
        sess.run([features, labels])
        print(labels.shape)

在上面的代码中,当我将数字450更改为500或更大时,它将引发'OutOfRangeError'。那可以确认我的验证数据集包含足够的数据,可用于450次迭代,批处理大小为32。

我尝试使用较小的批处理大小(即2)作为验证集,但仍然存在相同的错误。 我可以在input_fn中将验证集的“ num_epochs”设置为“ None”的情况下运行代码,但这似乎不是验证工作的方式。有什么帮助吗?

1 个答案:

答案 0 :(得分:1)

这种现象是正常的。从Tensorflow文档中:

  

如果迭代器到达数据集的末尾,则执行Iterator.get_next()操作将引发tf.errors.OutOfRangeError。此后,迭代器将处于无法使用的状态,如果要进一步使用它,则必须再次对其进行初始化。

设置dataset.repeat(None)时不会引发错误的原因是,由于数据集会无限重复,因此永远不会耗尽它。

要解决您的问题,应将代码更改为此:

n_steps = 450
...    

with tf.Session() as sess:
    # Training
    features, labels = input_fn(True, training_list, 32, 1, 1)

    for step in range(n_steps):
        sess.run([features, labels])
        ...
    ...
    # Validation
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    try:
        sess.run([features, labels])
        ...
    except tf.errors.OutOfRangeError:
        print("End of dataset")  # ==> "End of dataset"

您还可以对input_fn进行一些更改以在每个时期运行评估:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_initializable_iterator()
    return iterator

n_epochs = 10
freq_eval = 1

training_iterator = input_fn(True, training_list, 32, 1, 1)
training_features, training_labels = training_iterator.get_next()

val_iterator = input_fn(False, valid_list, 32, 1, 1)
val_features, val_labels = val_iterator.get_next()

with tf.Session() as sess:
    # Training
    sess.run(training_iterator.initializer)
    for epoch in range(n_epochs):
        try:
            sess.run([training_features, training_labels])
        except tf.errors.OutOfRangeError:
            pass

        # Validation
        if (epoch+1) % freq_eval == 0:
            sess.run(val_iterator.initializer)
            try:
                sess.run([val_features, val_labels])
            except tf.errors.OutOfRangeError:
                pass

如果您希望对幕后发生的事情有更好的了解,我建议您仔细研究this official guide