如何从tf.Dataset API为每个标签创建唯一批次

时间:2018-09-23 20:16:28

标签: python tensorflow tensorflow-datasets

让我们假设我的数据如下:

功能,标签

vector1,1

vector2,1

vector3,0

vector4,1,

vector5,0

我想获得一个批次,该批次应在批次大小条件(例如2)下代表每个迭代的每个标签中的随机1个样本。 例如,对于第一次迭代,我可以收到:

vector1,1

vector3,0

例如,对于第二次迭代,我可以收到:

vector4,1

vector5,0

以此类推...

您能否带我指导如何高效实施?

谢谢。

1 个答案:

答案 0 :(得分:3)

我将通过在每个标签上使用Datasetfilter分成两个流,然后例如依靠tf.contrib.data.choose_from_datasets将这些流合并回去。

一个优点是,您不会在处理过程中丢失任何样本(与您提供的示例相反)。

一个带有玩具数据的小例子:

import numpy as np
import tensorflow as tf

def gen():
  # generate random (value, label) pairs
  while True:
    yield (np.random.uniform(), np.random.randint(0, 2))

def split_and_merge(ds):
  return tf.contrib.data.choose_from_datasets(
    [ds.filter(lambda x, label: tf.equal(label, 0)),
     ds.filter(lambda x, label: tf.equal(label, 1))],
    tf.data.Dataset.range(2).repeat())

batch = (tf.data.Dataset
    .from_generator(gen, (tf.float32, tf.int32), (tf.TensorShape([]), tf.TensorShape([])))
    .apply(split_and_merge)
    .batch(2)
    .make_one_shot_iterator()
    .get_next())

sess = tf.InteractiveSession()
for _ in range(5):
  sess.run(batch)