我的数据集有不同的目录,每个目录对应一个类。每个目录中有不同数量的.tfrecords。我的问题是如何从每个目录中采样5个图像(每个.tfrecord文件对应一个图像)? 我的另一个问题是,如何对这些目录中的5个进行采样,然后从每个目录中采样5个图像。
我只想用 tf.data.Dataset 来实现。也就是说,我希望得到一个数据集,从中创建一个迭代器,每次调用 iterator.next() 都返回一批25个图像,即来自5个类、每个类5个样本。
答案 0 :(得分:10)
编辑:如果类别数量大于5,那么您可以使用新的 tf.contrib.data.sample_from_datasets() API(目前在 tf-nightly 中提供,并将在 TensorFlow 1.9 中可用)。
# One glob pattern per class directory.
# NOTE: the trailing `...` is a placeholder for the remaining class globs and
# must be replaced with real patterns before running. The list needs at least
# CLASSES_PER_BATCH entries, because each batch picks that many distinct classes.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)

# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of 5 classes to be chosen
# for a particular batch: shuffle all class ids and keep the first
# CLASSES_PER_BATCH of them.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH])

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. 25 examples per batch).
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, NUM_CLASSES)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`; each one-hot vector acts as the
# sampling weights for one draw.
example_dataset = tf.contrib.data.sample_from_datasets(per_class_datasets,
                                                       class_dataset)

# Finally, combine 25 consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)
如果您只有5个类,则可以为每个目录定义一个嵌套数据集,并使用 Dataset.interleave() 将它们组合起来:
# NOTE: Labels come from list position, so the 0th glob is assumed to hold
# the elements of class 0, the 1st those of class 1, and so on.
# Enumerating the globs yields (class, directory) pairs.
directories = tf.data.Dataset.from_tensor_slices(
    ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]
).apply(tf.contrib.data.enumerate_dataset())
# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
    """Build a dataset of `(record, class_label)` pairs for one directory.

    Args:
        class_label: Label paired with every record read from the directory.
        directory_glob: Glob pattern matching the TFRecord files of one class.

    Returns:
        A `tf.data.Dataset` whose elements are `(record, class_label)` tuples.
    """
    files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
    # Read the records from the matched files (the original passed the
    # not-yet-assigned `records` here instead of `files`).
    records = tf.data.TFRecordDataset(files)
    # Zip the records with their class.
    # NOTE: This part might not be necessary if the records contain information
    # about their class that can be parsed from them.
    # The label dataset repeats indefinitely; `zip` stops when `records` ends.
    labels = tf.data.Dataset.from_tensors(class_label).repeat()
    return tf.data.Dataset.zip((records, labels))
# Take 5 consecutive records from each of 5 directories in turn, then group
# every 25 records into one batch.
# NOTE: `cycle_length` and `block_length` aren't strictly necessary here,
# because the batch size is exactly `number of classes * images per class`.
# They become useful if you want to decouple these numbers.
merged_records = directories.interleave(
    per_directory_dataset, cycle_length=5, block_length=5).batch(25)
答案 1 :(得分:2)
请在下面找到一个可能的解决方案。
为了演示,我使用 Python 生成器而不是 TFRecords 作为输入(我假设您已经知道如何使用 TF 数据集读取并解析每个文件夹中的文件;如果不清楚,其他帖子中已有相关说明,例如 here)。
import tensorflow as tf
import numpy as np
def get_class_generator(class_id, num_el, el_shape=(32, 32), el_dtype=np.int32):
    """Return a dummy generator function for a single class.

    Args:
        class_id: Label yielded alongside every element.
        num_el: Number of elements each run of the generator produces.
        el_shape: Shape of each generated array.
        el_dtype: dtype used to create each generated array.

    Returns:
        A zero-argument generator function. Each call yields `num_el` pairs
        `(element, class_id)`, where the i-th element is an array of shape
        `el_shape` filled with the value `i`.
    """
    def class_generator():
        for x in range(num_el):
            # Fill the array with the element index so samples are distinguishable.
            element = np.ones(el_shape, dtype=el_dtype) * x
            yield element, class_id
    return class_generator
def concatenate_datasets(datasets):
    """Chain the given items into one dataset, preserving their order.

    Each item of `datasets` is wrapped with `tf.data.Dataset.from_tensors`
    and the resulting datasets are concatenated one after another.
    Snippet by user2781994 (https://stackoverflow.com/a/49069420/624547)
    """
    head, tail = datasets[0], datasets[1:]
    combined = tf.data.Dataset.from_tensors(head)
    for item in tail:
        single = tf.data.Dataset.from_tensors(item)
        combined = combined.concatenate(single)
    return combined
num_classes = 11
num_classes_per_batch = 5
# note: class_batch_size is 3 instead of 5 in this example
# just to distinguish between the 2 vars.
class_batch_size = 3


def _build_class_dataset(class_id):
    """Create the repeated, batched dummy dataset for one class.

    (note: replace tf.data.Dataset.from_generator(...) to suit your use-case
    e.g. tf.contrib.data.TFRecordDataset(glob.glob(perclass_tfrecords_path))
    .map(your_parsing_function))
    """
    # Random element count, simulating unequal number of samples per class.
    samples_in_class = np.random.randint(1, 60)
    generator = get_class_generator(class_id, num_el=samples_in_class)
    class_ds = tf.data.Dataset.from_generator(
        generator, (tf.int32, tf.int32), ([32, 32], []))
    # Repeat indefinitely, then group into batches of class_batch_size.
    return class_ds.repeat(-1).batch(class_batch_size)


# Initializing per-class datasets:
class_datasets = [_build_class_dataset(class_id)
                  for class_id in range(num_classes)]

# Initializing complete dataset.
# Zip the class datasets so each element holds one batch per class:
zipped_classes = tf.data.Dataset.zip(tuple(class_datasets))
# Turn each zipped element into a sequence of per-class batches:
per_class_batches = zipped_classes.flat_map(
    lambda *args: concatenate_datasets(args))
# Shuffle the per-class batches with a buffer of num_classes elements:
shuffled_batches = per_class_batches.shuffle(buffer_size=num_classes)
# Flatten each per-class batch back into individual (input, label) elements:
flat_examples = shuffled_batches.flat_map(
    lambda *args: tf.data.Dataset.from_tensor_slices(args))
# Regroup into batches of num_classes_per_batch * class_batch_size elements:
dataset = flat_examples.batch(num_classes_per_batch * class_batch_size)

# Visualizing results:
next_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(10):
        inputs, labels = sess.run(next_batch)
        print(">> batch {}".format(i))
        print("- inputs shape: {} ; label shape: {}".format(inputs.shape, labels.shape))
        print("- class values: {}".format(labels))
输出:
>> batch 0
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 1 1 1 0 0 0 10 10 10 2 2 2 9 9 9]
>> batch 1
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 2 2 2 3 3 3 5 5 5 6 6 6]
>> batch 2
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [ 9 9 9 8 8 8 4 4 4 3 3 3 10 10 10]
>> batch 3
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [7 7 7 8 8 8 6 6 6 6 6 6 2 2 2]
>> batch 4
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [1 1 1 0 0 0 1 1 1 8 8 8 5 5 5]
>> batch 5
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [2 2 2 4 4 4 9 9 9 5 5 5 5 5 5]
>> batch 6
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [0 0 0 7 7 7 3 3 3 9 9 9 7 7 7]
>> batch 7
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [10 10 10 10 10 10 1 1 1 6 6 6 7 7 7]
>> batch 8
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [4 4 4 3 3 3 5 5 5 6 6 6 3 3 3]
>> batch 9
- inputs shape: (15, 32, 32) ; label shape: (15,)
- class values: [8 8 8 9 9 9 2 2 2 8 8 8 0 0 0]