我有一个问题,我会尝试用一个例子来解释,以便更容易理解。
我想对橘子(O)和苹果(A)进行分类。出于技术/遗留原因(网络中的组件),每个批处理应该只有O或只有A示例。因此,在示例级别进行传统的改组是不可能/不够的,因为我无法承担包含O和A示例混合的批次。然而,某种改组是可取的,因为训练深度网络是一种常见的做法。
这些是我采取的步骤:
答案 0 :(得分:2)
如果你使用Dataset
api,它相当简单。只需压缩O
和A
批次,然后使用Dataset.map()
应用随机选择功能:
ds0 = tf.data.Dataset.from_tensor_slices([0])
ds0 = ds0.repeat()
ds0 = ds0.batch(5)
ds1 = tf.data.Dataset.from_tensor_slices([1])
ds1 = ds1.repeat()
ds1 = ds1.batch(5)
def rand_select(ds0, ds1):
rval = tf.random_uniform([])
return tf.cond(rval<0.5, lambda: ds0, lambda: ds1)
dataset = tf.data.Dataset()
dataset = dataset.zip((ds0, ds1)).map(lambda ds0, ds1: rand_select(ds0, ds1))
iterator = dataset.make_one_shot_iterator()
ds = iterator.get_next()
with tf.Session() as sess:
for _ in range(5):
print(sess.run(ds))
> [0 0 0 0 0]
[1 1 1 1 1]
[1 1 1 1 1]
[0 0 0 0 0]
[0 0 0 0 0]