在每个类的数据集中拆分张量流数据集

时间:2018-09-12 16:13:14

标签: python tensorflow tensorflow-datasets

我有一个从tfrecord文件创建的数据集。该数据集包含5个不同的类。

现在,我想从每个批次中创建具有固定数量的元素(例如8个)的批次。因此,应创建包含40个元素的批次,每个类包含8个元素。

tf.data是否可能?

1 个答案:

答案 0 :(得分:2)

最容易做的是(也许不是很方便):

a)准备5个不同的TFRecords,每个仅包含一个特定类别的元素。

b)创建5个不同的tf.data.TFRecordDataset实例,并因此创建5个不同的迭代器。

c)然后在主代码中:

iterators =  [....] # Store your iterators in a list
data = list(map(lambda x : x.get_next(), iterators))
data_to_use = tf.concat(....) # Concat your data in one single batch of `40` elements.

另一种方法(无需创建单独的数据集)

a)仅使用一个TFRecord。但要为其创建5个不同的实例

b)在每种情况下,都使用tf.data.filter(predicate) API的tf.data方法来过滤属于一个特定类的记录。为此,您将必须编写一个函数,该函数可以检查每个记录的类。

c)然后按照先前解决方案中的步骤c)