我正在尝试使用tf.data输入管道来允许在批处理数据时在运行时选择数据集。以下应该允许我运行InitTrain / InitTest来选择要使用的数据集:
TrainData = tf.data.Dataset.from_generator(TrainGenerator,tf.float32)
TestData = tf.data.Dataset.from_generator(TestGenerator,tf.float32)
DataIterator = tf.data.Iterator.from_structure(tf.float32)
DataNext = DataIterator.get_next()
InitTrain = DataIterator.make_initializer(TrainData)
InitTest = DataIterator.make_initializer(TestData)
如果不使用迭代器,我可以使用以下内容来获取填充批处理:
TrainData.padded_batch(1000,someshape)
如何在保留选择输入数据源的能力的同时批量处理数据?
一个可能的解决方法是创建一个新的from_generator数据集,并创建一个生成器,在DataNext上调用sess.run来创建一个可以批处理的数据集但是这会导致分配运行调用,所以我怀疑这是它的方式意在使用。
答案 0 :(得分:0)
每个数据集的功能都独立于另一个,迭代器根据您使用的初始化程序在不同的数据集之间切换。因此,如果您想批量处理一个数据集而不是另一个数据集,则可以使用:
TrainData = tf.data.Dataset.from_generator(TrainGenerator,tf.float32)
TrainData = TrainData.batch(100)
TestData = tf.data.Dataset.from_generator(TestGenerator,tf.float32)
DataIterator = tf.data.Iterator.from_structure(tf.float32)
DataNext = DataIterator.get_next()
InitTrain = DataIterator.make_initializer(TrainData)
InitTest = DataIterator.make_initializer(TestData)
顺便提一下,您还可以使用两个管道不同的事实(如果您愿意)为每个数据集执行不同的映射。因此,例如,您可以对列车数据使用批量标准化,但不会对测试数据进行标准化(如果您已经在整个数据集中进行了标准化)。