这个问题在某种程度上是Produce balanced mini batch with Dataset API的扩展,并引用了tf.data.Dataset
文档中的interleave
函数。
上下文:
假设您具有以下条件:
n=4
个类的数据集然后我们可以按以下方式构造标记的数据集:
path_ds = tf.data.Dataset.from_tensor_slices(files)
indx_ds = tf.data.Dataset.from_tensor_slices(labels)
ds = tf.data.Dataset.zip((path_ds, indx_ds))
如果我想为n
类(其中n>2
与链接的SO问题不同)制作平衡的微型批次,则:
# assuming class index starts at 0
class_ds = tf.data.Dataset.range(0, n).map(lambda e: tf.cast(e, tf.int32))
ids = class_ds.interleave(
lambda index : filter_for_class(index, ds),
cycle_length=n, block_length=1
)
每个班级都会产生一个示例,其中:
def filter_for_class(class_index, dataset):
return dataset.filter(lambda path, label: tf.math.equal(label, class_index))
通常,如果上述示例中的b=1
(block_length
),则:
.interleave(
...,
cycle_length=n*b,
block_length=b
)
将确保只要n*b
可将我们的迷你批处理整除,我们就能看到偶数类(只要每个类有足够的数据)。
所以我的问题是,如何使用tf.data.Dataset
中的build操作产生不平衡的迷你批处理。
例如假设如果我的小批量生产具有m
个元素,那么对于我的n
类,我希望每个类都具有以下比率:
class_ratios = {
0: 0.6,
1: 0.1,
2: 0.2,
3: 0.1
}
# if m = 100 then 60 examples from class 0, 10 from class 1, 20 from class 2 and 10 from class 3
与先前链接的问题不同,有些限制是每个文件正好是一个记录,并且可以从文件/路径名中提取记录标签。
注意:如果来自一个类的数据用完,上述方法将导致最终批次不平衡