在Tensorflow from_generator中的数据集之间进行运行时切换吗?

时间:2018-09-22 20:09:02

标签: python tensorflow

我有一个巨大的数据集(大约50 GB),并且正在使用类似以下的Python生成器加载它:

def data_generator(self, images_path):
    with open(self.temp_csv, 'r') as f:
        for image in f.readlines():
            # Something going on... 
            yield (X, y)

重要的是,我正在使用单个生成器来训练和验证数据,并且正在尝试在运行时更改 self.temp_csv 。但是,事情并没有按预期进行,并且通过更新变量 self.temp_csv 进行了设置,该变量应该在训练集和验证集之间切换,而不会调用 open ,并且我结束了一遍又一遍地遍历同一数据集。我想知道是否有可能使用 Dataset.from_generator ,并且在运行时,我切换到另一个数据集进行验证阶段。这是我指定发生器的方式。谢谢!

def get_data(self):

    with tf.name_scope('data'):

        data_generator = lambda: self.data_generator(images_path=self.data_path)

        my_data = tf.data.Dataset.from_generator(
        generator=data_generator,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape([None]), tf.TensorShape([None]))
        ).batch(self.batch_size).prefetch(2)

        img, self.label = my_data.make_one_shot_iterator().get_next()
        self.img = tf.reshape(img, [-1, CNN_INPUT_HEIGHT, CNN_INPUT_WIDTH, CNN_INPUT_CHANNELS])

1 个答案:

答案 0 :(得分:1)

您可以使用重新初始化的迭代器或可填充的迭代器在两个数据集之间进行切换,如official docs所示。

但是,如果您想使用生成器读取所有数据,然后创建一个训练和验证拆分,那么它就不是那么简单了。

如果您有单独的验证文件,则可以简单地创建一个新的验证数据集并使用上面显示的迭代器。 如果不是这种情况,则skip()和take()之类的方法可以帮助您拆分数据,但是需要考虑如何进行良好的拆分。