Tensorflow 1.14+:使用Dataset API故意制作不平衡的迷你批处理

时间:2019-08-02 17:24:59

标签: python tensorflow tensorflow-datasets

这个问题在某种程度上是Produce balanced mini batch with Dataset API的扩展,并引用了tf.data.Dataset文档中的interleave函数。

上下文:

假设您具有以下条件:

  • 具有n=4个类的数据集
  • 每个文件对应一条记录的文件名列表
  • 每个文件的标签

然后我们可以按以下方式构造标记的数据集:

path_ds = tf.data.Dataset.from_tensor_slices(files)
indx_ds = tf.data.Dataset.from_tensor_slices(labels)
ds = tf.data.Dataset.zip((path_ds, indx_ds))

如果我想为n类(其中n>2与链接的SO问题不同)制作平衡的微型批次,则:

# assuming class index starts at 0
class_ds = tf.data.Dataset.range(0, n).map(lambda e: tf.cast(e, tf.int32))

ids = class_ds.interleave(
    lambda index : filter_for_class(index, ds),
    cycle_length=n, block_length=1
)

每个班级都会产生一个示例,其中:

def filter_for_class(class_index, dataset):
    return dataset.filter(lambda path, label: tf.math.equal(label, class_index))

通常,如果上述示例中的b=1block_length),则:

.interleave( 
   ...,
   cycle_length=n*b, 
   block_length=b
)

将确保只要n*b可将我们的迷你批处理整除,我们就能看到偶数类(只要每个类有足够的数据)。

所以我的问题是,如何使用tf.data.Dataset中的build操作产生不平衡的迷你批处理。

例如假设如果我的小批量生产具有m个元素,那么对于我的n类,我希望每个类都具有以下比率:

class_ratios = {
    0: 0.6,
    1: 0.1,
    2: 0.2,
    3: 0.1
}

# if m = 100 then 60 examples from class 0, 10 from class 1, 20 from class 2 and 10 from class 3

与先前链接的问题不同,有些限制是每个文件正好是一个记录,并且可以从文件/路径名中提取记录标签。

注意:如果来自一个类的数据用完,上述方法将导致最终批次不平衡

0 个答案:

没有答案