Tensorflow:如何确保每批中的所有样品都有不同的标签?

时间:2018-03-07 09:02:17

标签: python tensorflow keras

我想知道是否有办法在Tensorflow中生成批量生成约束。特别是,我想生产包含不同标签的批次。

假设我有五个可能的标签{A, B, C, D, E},我希望批量为(A, C, E, D, B)(B,E,D,C,A)。基本上,我想避免使用(A, A, D, E, C)(A, B, B, B, E)等标签进行批量处理。

1 个答案:

答案 0 :(得分:2)

实施您的要求

批处理只是从它获取的任何内容中提取BATCH_SIZE个样本并将它们打包在一起,所以从技术上讲,这是可能的。但是,这取决于您,确保batch()的输入按照您想要的方式排序。

执行此操作的最有效方法可能是拥有5个tf.data.Dataset个,每个都有一个特定标签,zip它们在一起以获得一个"批量"标签的数据集始终采用相同的顺序,然后.map上的.shuffle数据集,以获取批次的随机排列并将其提供给您的网络。

我也会在随机排列后输入data = [ tf.constant([chr(ord('A')+i), chr(ord('a')+i) ]) for i in range(5) ] per_label_datasets = [tf.data.Dataset.from_tensor_slices(d) for d in data] dataset = tf.data.Dataset.zip(tuple(per_label_datasets)) # now an item has shape len(per_label_datasets) and one item from each dataset = dataset.map(lambda *args : tf.random_shuffle(args)) # lambda needed because random_shuffle takes only one argument dataset = dataset.shuffle(10) # optional it = dataset.make_one_shot_iterator() batch = it.get_next() sess = tf.Session() print(sess.run(batch)) print(sess.run(batch)) ,但请确保网络始终不会以相同的顺序查看同一批次。

在代码中看起来像:

[b'a' b'c' b'd' b'e' b'b']
[b'C' b'B' b'A' b'D' b'E']

示例输出:

http://www.etf.com/etf-finder-funds-api//-aum/100/100/1

个人注释

我不知道你正在使用什么型号,我认为有些模型对此有意义,但在大多数模型中批次中的样品顺序毫无意义因为在计算损失时,结果在一个批次中被平均。所以,如果你真的需要这个,有办法做,但在开始编码管道之前确保你确实需要它。