单个数据集的多个独立迭代器

时间:2018-09-22 19:57:35

标签: python tensorflow tensorflow-datasets

假设我有一个训练数据集“ data_train”,我想制作两个独立的迭代器,它们都在data_train上进行迭代。我将用来训练网络“ iter_train”的第一个迭代器,其中iter_train.get_next()的输出将是我对其进行训练的批处理。第二个迭代器将在我训练“ iter_eval”时用来评估整个训练数据集,以监视训练进度。

当前,如果我只有一个迭代器“ iter_single”,并且我想在某个时期的中途评估训练损失,则必须重置迭代器,使用iter_single评估整个数据集,然后从头开始进行训练具有iter_single的数据集。因此,除非我浪费时间在不进行任何操作的情况下遍历数据,否则我将不会完成上一个时期并忽略一半的数据集。

我已经尝试为一个数据集创建两个迭代器,但是,通过重置一个迭代器会重置另一个迭代器,这使得拥有两个迭代器毫无意义。

1 个答案:

答案 0 :(得分:0)

,如果您的数据集大小不大(大的意思是训练和验证数据都可以在训练和评估期间保存到内存中),则可以使用以下代码:

首先,读取并解析您的数据,然后将它们传递给Tensorflow数据集对象:

def get_image_dataset(dir_path, batch_size, split=0.7):

    # Parse data and return them in array format (Numpy)
    train_data, val_data = parse_data(dir_path, split)

    # Create the dataset for our train data
    train_data = tf.data.Dataset.from_tensor_slices(train_data)
    train_data = train_data.batch(batch_size)

    # Create the dataset for our test data
    val_data = tf.data.Dataset.from_tensor_slices(val_data)
    val_data = val_data.batch(batch_size)

    return train_data, val_data

第二,为火车和验证数据定义迭代器和初始化器:

def get_data():
    with tf.name_scope('data'):

        train_data, test_data =  get_image_dataset(self.batch_size)
        iterator = tf.data.Iterator.from_structure(output_types=train_data.output_types, output_shapes=train_data.output_shapes)

        # Define one iterator for your data
        img, self.label = iterator.get_next()

        # Example of application on MNIST dataset
        img = tf.reshape(img, [-1, CNN_INPUT_HEIGHT, CNN_INPUT_WIDTH, CNN_INPUT_CHANNELS])

        # Define two initializers for either train or test (validation) data
        self.train_init = iterator.make_initializer(train_data)
        self.test_init = iterator.make_initializer(test_data)

第三第四,在训练/测试网络的同时,使用训练/测试初始化Tensorflow图像这样的数据集:

培训

def train_network_one_epoch(...):

    # Initialize training
    sess.run(self.train_init)

    # Run training graph nodes

    return something

测试

def evaluate_network(...):

    # Initialize testing
    sess.run(self.test_init)

    # Run evaluation graph nodes

    return something

您可以查看this示例,该示例清楚地演示了此过程。