TF DATA API:如何生成对象集识别的张量流输入

时间:2018-05-27 07:53:00

标签: tensorflow tensorflow-datasets

考虑这个问题:从图像数据集(如ImageNet)中的随机主题中选择一个随机数量的样本作为Tensorflow图的输入元素,它用作对象集识别器。对于每个批次,每个类具有相同数量的样本以便于计算。但是对于一个类,不同的批次将具有不同数量的图像,即batch_0: finalResult = null; for each id in dfB: for query condition of this id: tempResult = query dfA union tempResult to finalResult = 2; batch_1000:num_imgs_per_cls = 3。

如果Tensorflow中存在现有功能,将非常感谢从头开始对整个过程的解释(如图像目录)。

1 个答案:

答案 0 :(得分:2)

@mrry here有一个非常相似的答案。

采样平衡批次

在面部识别中,我们经常使用三重损失(或类似的损失)来训练模型。对三元组进行抽样计算损失的通常方法是创建一个平衡的图像批次,其中我们有10个不同的类(即10个不同的人),每个类有5个图像。在此示例中,这使得总批量大小为50。

更一般地说,问题是对num_classes_per_batch(示例中为10)类进行采样,然后为每个类采样num_images_per_class(示例中为5)图像。总批量大小为:

batch_size = num_classes_per_batch * num_images_per_class

每个类都有一个数据集

处理许多不同课程(MS-Celeb中的100,000)的最简单方法是为每个班级创建一个数据集。
例如,您可以为每个类创建一个tfrecord,并创建如下数据集:

# Build one dataset per class.
filenames = ["class_0.tfrecords", "class_1.tfrecords"...]
per_class_datasets = [tf.data.TFRecordDataset(f).repeat(None) for f in filenames]

数据集中的样本

现在我们希望能够从这些数据集中进行采样。例如,我们希望批次中包含以下标签:

1 1 1 3 3 3 9 9 9 4 4 4

这相当于num_classes_per_batch=4num_images_per_class=3

为此,我们需要使用将在r1.9中发布的功能。该函数应调用tf.contrib.data.choose_from_datasets(有关此问题的讨论,请参阅here) 它应该看起来像:

def choose_from_datasets(datasets, selector):
    """Chooses elements with indices from selector among the datasets in `datasets`."""

因此,我们创建此selector,输出1 1 1 3 3 3 9 9 9 4 4 4并将其与datasets合并,以获取将输出平衡批次的最终数据集:

def generator(_):
    # Sample `num_classes_per_batch` classes for the batch
    sampled = tf.random_shuffle(tf.range(num_classes))[:num_classes_per_batch]
    # Repeat each element `num_images_per_class` times
    batch_labels = tf.tile(tf.expand_dims(sampled, -1), [1, num_images_per_class])
    return tf.to_int64(tf.reshape(batch_labels, [-1]))

selector = tf.contrib.data.Counter().map(generator)
selector = selector.apply(tf.contrib.data.unbatch())

dataset = tf.contrib.data.choose_from_datasets(datasets, selector)

# Batch
batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)

您可以使用每晚TensorFlow版本并使用DirectedInterleaveDataset作为解决方法来测试:

# The working option right now is 
from tensorflow.contrib.data.python.ops.interleave_ops import DirectedInterleaveDataset
dataset = DirectedInterleaveDataset(selector, datasets)

我也写过关于此解决方法的here