Tensorflow - 以“批次级别”而非“示例级别”进行混洗

时间:2018-02-11 02:40:52

标签: tensorflow batch-processing shuffle

我有一个问题,我会尝试用一个例子来解释,以便更容易理解。

我想对橘子(O)和苹果(A)进行分类。出于技术/遗留原因(网络中的组件),每个批处理应该只有O或只有A示例。因此,在示例级别进行传统的改组是不可能/不够的,因为我无法承担包含O和A示例混合的批次。然而,某种改组是可取的,因为训练深度网络是一种常见的做法。

这些是我采取的步骤:

  • 我首先需要将原始数据/示例转换为TFRecords。
  • 我改变了原始示例的顺序,然后我创建了单独的TFRecords,其中只包含改组的O示例,或者只包含洗牌的A示例。我们称之为“示例级别”改组。这是离线发生的,只发生一次。
  • 此时我有“干净批次”:仅包含O个示例的O-baches和仅包含A个示例的A-batches。
  • 我不想首先使用所有O批次向网络提供数据,然后依次为所有A批次提供网络。这可能对融合没什么帮助。
  • 我可以在“批次级别”中对这些批次进行洗牌,即不影响其内部?

1 个答案:

答案 0 :(得分:2)

如果你使用Dataset api,它相当简单。只需压缩OA批次,然后使用Dataset.map()应用随机选择功能:

ds0 = tf.data.Dataset.from_tensor_slices([0])
ds0 = ds0.repeat()
ds0 = ds0.batch(5)
ds1 = tf.data.Dataset.from_tensor_slices([1])
ds1 = ds1.repeat()
ds1 = ds1.batch(5)

def rand_select(ds0, ds1):
    rval = tf.random_uniform([])
    return tf.cond(rval<0.5, lambda: ds0, lambda: ds1)

dataset = tf.data.Dataset()
dataset = dataset.zip((ds0, ds1)).map(lambda ds0, ds1: rand_select(ds0, ds1))
iterator = dataset.make_one_shot_iterator()
ds = iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        print(sess.run(ds))

> [0 0 0 0 0]
  [1 1 1 1 1]
  [1 1 1 1 1]
  [0 0 0 0 0]
  [0 0 0 0 0]