Tensorflow:如何确保每批中的所有样品都使用相同的标签?

时间:2018-02-11 01:20:20

标签: tensorflow batch-processing

我想知道是否有一些方法可以在Tensorflow中生成批量生成约束。例如,假设我们正在一个庞大的数据集上训练CNN来进行图像分类。是否有可能强制Tensorflow生成所有样本属于同一类的批次?比如,一批图像都用“Apple”标记,另一组图像都标有“橙色”。

我问这个问题的原因是我想做一些实验,看看不同级别的改组如何影响最终训练的模型。通常的做法是为CNN培训进行样本级改组,每个人都在这样做。我只是想亲自检查一下,从而获得更生动,更直接的知识。

谢谢!

1 个答案:

答案 0 :(得分:2)

可以使用

Dataset.filter()

labels = np.random.randint(0, 10, (10000))
data = np.random.uniform(size=(10000, 5))

ds = tf.data.Dataset.from_tensor_slices((data, labels))
ds = ds.filter(lambda data, labels: tf.equal(labels, 1)) #comment this line out for unfiltered case
ds = ds.batch(5)
iterator = ds.make_one_shot_iterator()
vals = iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        py_data, py_labels = sess.run(vals)
        print(py_labels)

ds.filter()

 > [1 1 1 1 1]
   [1 1 1 1 1]
   [1 1 1 1 1]
   [1 1 1 1 1]
   [1 1 1 1 1]

没有ds.filter()

  > [8 0 7 6 3]
    [2 4 7 6 1]
    [1 8 5 5 5]
    [7 1 7 4 0]
    [7 1 8 0 0]

修改。以下代码显示如何使用可馈送迭代器动态执行批处理标签选择。请参阅" Creating an iterator"

labels = ['Apple'] * 100 + ['Orange'] * 100
data = list(range(200))
random.shuffle(labels)

batch_size = 4

ds_apple = tf.data.Dataset.from_tensor_slices((data, labels)).filter(
  lambda data, label: tf.equal(label, 'Apple')).batch(batch_size)
ds_orange = tf.data.Dataset.from_tensor_slices((data, labels)).filter(
  lambda data, label: tf.equal(label, 'Orange')).batch(batch_size)

handle = tf.placeholder(tf.string, [])
iterator = tf.data.Iterator.from_string_handle(
  handle, ds_apple.output_types, ds_apple.output_shapes)
batch = iterator.get_next()

apple_iterator = ds_apple.make_one_shot_iterator()
orange_iterator = ds_orange.make_one_shot_iterator()

with tf.Session() as sess:
  apple_handle = sess.run(apple_iterator.string_handle())
  orange_handle = sess.run(orange_iterator.string_handle())

  # loop and switch back and forth between apples and oranges
  for _ in range(3):
    feed_dict = {handle: apple_handle}
    print(sess.run(batch, feed_dict=feed_dict))
    feed_dict = {handle: orange_handle}
    print(sess.run(batch, feed_dict=feed_dict))

典型输出如下。请注意,data值在Apple和Orange批次中单调增加,表明迭代器未重置。

> (array([2, 3, 6, 7], dtype=int32), array([b'Apple', b'Apple', b'Apple', b'Apple'], dtype=object))
  (array([0, 1, 4, 5], dtype=int32), array([b'Orange', b'Orange', b'Orange', b'Orange'], dtype=object))
  (array([ 9, 13, 15, 19], dtype=int32), array([b'Apple', b'Apple', b'Apple', b'Apple'], dtype=object))
  (array([ 8, 10, 11, 12], dtype=int32), array([b'Orange', b'Orange', b'Orange', b'Orange'], dtype=object))
  (array([21, 22, 23, 25], dtype=int32), array([b'Apple', b'Apple', b'Apple', b'Apple'], dtype=object))
  (array([14, 16, 17, 18], dtype=int32), array([b'Orange', b'Orange', b'Orange', b'Orange'], dtype=object))