我有一个从tfrecord文件创建的数据集。该数据集包含5个不同的类。
现在,我想从每个批次中创建具有固定数量的元素(例如8个)的批次。因此,应创建包含40个元素的批次,每个类包含8个元素。
tf.data是否可能?
答案 0 :(得分:2)
最容易做的是(也许不是很方便):
a)准备5个不同的TFRecords
,每个仅包含一个特定类别的元素。
b)创建5
个不同的tf.data.TFRecordDataset
实例,并因此创建5
个不同的迭代器。
c)然后在主代码中:
iterators = [....] # Store your iterators in a list
data = list(map(lambda x : x.get_next(), iterators))
data_to_use = tf.concat(....) # Concat your data in one single batch of `40` elements.
a)仅使用一个TFRecord。但要为其创建5
个不同的实例
b)在每种情况下,都使用tf.data.filter(predicate)
API的tf.data
方法来过滤属于一个特定类的记录。为此,您将必须编写一个函数,该函数可以检查每个记录的类。
c)然后按照先前解决方案中的步骤c)
。