How do I create a tf.data.Dataset from directories of TFRecords?

Time: 2018-05-15 18:05:50

Tags: tensorflow tensorflow-datasets

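My dataset consists of several directories, one per class. Each directory contains a different number of .tfrecords files. My first question is: how can I sample 5 images from each directory (each .tfrecord file corresponds to one image)? My second question is: how can I sample 5 of these directories, and then sample 5 images from each of them?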

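I want to do this using tf.data.Dataset only. So I would like a dataset from which I get an iterator, where iterator.next() gives me a batch of 25 images containing 5 samples from each of 5 classes.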

2 answers:

Answer 0: (score: 10)

EDIT: If the number of classes is greater than 5, you can use the new tf.contrib.data.sample_from_datasets() API (currently available in tf-nightly, and due to be available in TensorFlow 1.9).

import tensorflow as tf

directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)


# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of 5 classes to be chosen
# for a particular batch.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH])

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. 25 examples per batch).
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, NUM_CLASSES)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`.
example_dataset = tf.contrib.data.sample_from_datasets(per_class_datasets,
                                                       class_dataset)

# Finally, combine 25 consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)
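
To make the role of the one-hot selector stream more concrete, here is a minimal plain-Python sketch of what one batch's worth of `class_dataset` elements looks like. It does not use tf.data at all; `NUM_CLASSES = 8`, the `one_hot` helper and `selector_vectors_for_one_batch` are hypothetical names chosen only for illustration, and the assumption is that a one-hot weight vector makes the sampler draw from the single dataset whose weight is 1.

```python
import random

NUM_CLASSES = 8  # hypothetical number of class directories
CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5


def one_hot(index, depth):
    """Return a list of `depth` zeros with a 1.0 at position `index`."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def selector_vectors_for_one_batch(rng):
    """Model one shuffled class vector after it has been expanded by flat_map."""
    # Pick CLASSES_PER_BATCH distinct class ids in random order.
    classes = rng.sample(range(NUM_CLASSES), CLASSES_PER_BATCH)
    # One one-hot vector per chosen class; the whole group is then repeated,
    # which mirrors from_tensor_slices(...).repeat(EXAMPLES_PER_CLASS_PER_BATCH).
    return [one_hot(c, NUM_CLASSES) for c in classes] * EXAMPLES_PER_CLASS_PER_BATCH


rng = random.Random(0)
vectors = selector_vectors_for_one_batch(rng)
# Index of the dataset each selector vector would pick.
chosen = [v.index(1.0) for v in vectors]
print(chosen)
```

Under this model each group of 25 selectors names exactly 5 distinct classes, each 5 times, in a repeating order such as c0, c1, c2, c3, c4, c0, c1, and so on, so batching by 25 lines up with the class choice.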

If you have exactly 5 classes, you can define a nested dataset for each directory and combine them using Dataset.interleave():

# NOTE: We're assuming that the 0th directory contains elements from class 0, etc.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]
directories = tf.data.Dataset.from_tensor_slices(directories)
directories = directories.apply(tf.contrib.data.enumerate_dataset())    

# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
  files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
  records = tf.data.TFRecordDataset(files)
  # Zip the records with their class. 
  # NOTE: This part might not be necessary if the records contain information about
  # their class that can be parsed from them.
  return tf.data.Dataset.zip(
      (records, tf.data.Dataset.from_tensors(class_label).repeat(None)))

# NOTE: The `cycle_length` and `block_length` here aren't strictly necessary,
# because the batch size is exactly `number of classes * images per class`.
# However, these arguments may be useful if you want to decouple these numbers.
merged_records = directories.interleave(per_directory_dataset,
                                        cycle_length=5, block_length=5)
merged_records = merged_records.batch(25)
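
As a rough illustration of how `cycle_length` and `block_length` shape the output order, the following plain-Python sketch models the interleaving step. It is a simplified model, not the tf.data implementation: it assumes there are exactly `cycle_length` sources, and the `interleave` function and the `rec_<class>_<index>` record names are hypothetical stand-ins.

```python
from itertools import islice


def interleave(sources, cycle_length, block_length):
    """Simplified model: round-robin over the first `cycle_length` sources,
    taking up to `block_length` consecutive items from each in turn."""
    active = [iter(s) for s in sources[:cycle_length]]
    while active:
        still_active = []
        for it in active:
            block = list(islice(it, block_length))
            yield from block
            # A short block means this source is exhausted, so drop it.
            if len(block) == block_length:
                still_active.append(it)
        active = still_active


# Hypothetical stand-ins for the five per-directory datasets: (record, label) pairs.
sources = [[("rec_{}_{}".format(c, i), c) for i in range(10)] for c in range(5)]

merged = list(interleave(sources, cycle_length=5, block_length=5))
first_batch_labels = [label for _, label in merged[:25]]
print(first_batch_labels)
```

In this model every run of 25 consecutive elements holds 5 records from each of the 5 classes, grouped class by class, which is why batching by 25 gives the requested layout.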

Answer 1: (score: 2)

Please find a possible solution below.

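For demonstration purposes, I use Python generators instead of TFRecords as inputs (I assume you know how to use TF Datasets to read and parse the files in each folder; if not, other threads cover this, e.g. here).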

import tensorflow as tf
import numpy as np

def get_class_generator(class_id, num_el, el_shape=(32, 32), el_dtype=np.int32):
    """ Returns a dummy generator, 
        outputting "num_el" elements of a single class (input data & class label) 
    """
    def class_generator():
        for x in range(num_el):
            element = np.ones(el_shape, dtype=el_dtype) * x
            yield element, class_id
    return class_generator


def concatenate_datasets(datasets):
    """ Concatenate a list of datasets together.
        Snippet by user2781994 (https://stackoverflow.com/a/49069420/624547)
    """
    ds0 = tf.data.Dataset.from_tensors(datasets[0])
    for ds1 in datasets[1:]:
        ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
    return ds0


num_classes = 11
class_batch_size = 3
num_classes_per_batch = 5
# note: using 3 instead of 5 for class_batch_size in this example 
#       just to distinguish between the 2 vars.

# Initializing per-class datasets:
# (note: replace tf.data.Dataset.from_generator(...) to suit your use-case,
#        e.g. tf.data.TFRecordDataset(glob.glob(perclass_tfrecords_path))
#                    .map(your_parsing_function))
class_datasets = [tf.data.Dataset
                 .from_generator(get_class_generator(
                      class_id, num_el=np.random.randint(1, 60) 
                      # ^ simulating unequal number of samples per class
                      ), (tf.int32, tf.int32), ([32, 32], []))
                 .repeat(-1)
                 .batch(class_batch_size)
                  for class_id in range(num_classes)]

# Initializing complete dataset:
dataset = (tf.data.Dataset
           # Concatenating all the class datasets together:
           .zip(tuple(class_datasets))
           .flat_map(lambda *args: concatenate_datasets(args))
           # Shuffling the class datasets:
           .shuffle(buffer_size=num_classes)
           # Flattening batches from shape (num_classes_per_batch, class_batch_size, ...)
           # into (num_classes_per_batch * class_batch_size, ...):
           .flat_map(lambda *args: tf.data.Dataset.from_tensor_slices(args))
           # Returning correct number of el. (num_classes_per_batch * class_batch_size):
           .batch(num_classes_per_batch * class_batch_size))

# Visualizing results:
next_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(10):
        batch = sess.run(next_batch)
        print(">> batch {}".format(i))
        print("- inputs shape: {} ; label shape: {}".format(batch[0].shape,batch[1].shape))
        print("- class values: {}".format(batch[1]))

Output:

>> batch 0
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 1  1  1  0  0  0 10 10 10  2  2  2  9  9  9]
>> batch 1
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 2 2 2 3 3 3 5 5 5 6 6 6]
>> batch 2
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 9  9  9  8  8  8  4  4  4  3  3  3 10 10 10]
>> batch 3
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [7 7 7 8 8 8 6 6 6 6 6 6 2 2 2]
>> batch 4
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [1 1 1 0 0 0 1 1 1 8 8 8 5 5 5]
>> batch 5
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [2 2 2 4 4 4 9 9 9 5 5 5 5 5 5]
>> batch 6
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 7 7 7 3 3 3 9 9 9 7 7 7]
>> batch 7
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [10 10 10 10 10 10  1  1  1  6  6  6  7  7  7]
>> batch 8
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [4 4 4 3 3 3 5 5 5 6 6 6 3 3 3]
>> batch 9
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [8 8 8 9 9 9 2 2 2 8 8 8 0 0 0]
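
Some batches in the output above contain the same class twice (batch 3 has class 6 twice, batch 4 has class 1 twice). A plausible explanation is that a streaming shuffle mixes per-class mini-batches from consecutive passes over the classes. The plain-Python sketch below models that under stated assumptions: `shuffle_buffer` is a hypothetical, simplified stand-in for a buffered shuffle, and each integer stands for one per-class mini-batch. It is a reading of the pipeline, not a description of tf.data internals.

```python
import random
from itertools import cycle, islice


def shuffle_buffer(stream, buffer_size, rng):
    """Simplified streaming shuffle: hold `buffer_size` items, emit one at
    random, then refill from the stream."""
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) == buffer_size:
            yield buffer.pop(rng.randrange(buffer_size))
    while buffer:  # drain the remainder once the stream ends
        yield buffer.pop(rng.randrange(len(buffer)))


num_classes = 11
num_classes_per_batch = 5

# Each item stands for one per-class mini-batch; the zip + concatenate stage
# is modelled as emitting class ids 0..10 in order, over and over.
stream = cycle(range(num_classes))
shuffled = shuffle_buffer(stream, buffer_size=num_classes, rng=random.Random(0))

batches = [list(islice(shuffled, num_classes_per_batch)) for _ in range(1000)]
with_repeats = sum(1 for b in batches if len(set(b)) < num_classes_per_batch)
print("batches with a repeated class:", with_repeats, "of", len(batches))
```

If each batch must contain 5 distinct classes, the class choice probably needs to be made per batch, as in the first answer, rather than by shuffling a continuous stream.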