构建tensorflow数据集迭代器以生成具有特殊结构的批次

时间:2018-07-31 11:36:19

标签: python tensorflow dataset

正如标题中所述,我需要具有特殊结构的批次:

1111
5555
2222

每个数字代表特征向量。因此,每个类N=4{1,2,5})都有M=3个向量,批处理大小为NxM=12

要完成此任务,我使用Tensorflow Dataset API和tfrecords:

  • 使用功能构建tfrecord,每个类1个文件
  • 为每个类创建Dataset实例,并为每个类初始化迭代器
  • 要生成批次,我会从迭代器列表中抽样M个随机迭代器,并从每个迭代器中生成N个特征向量
  • 然后我将功能堆叠在一起
  • ...
  • 批量生产

我担心的是,我有数百个类(在功能中可能是数千个),并且从内存和性能的角度来看,为每个类存储迭代器看起来并不好。

有更好的方法吗?

1 个答案:

答案 0 :(得分:2)

如果您具有按类排序的文件列表,则可以交错数据集:

import tensorflow as tf

N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)

dataset = tf.data.Dataset.from_tensor_slices(record_files)
# Consider tf.contrib.data.parallel_interleave for parallelization
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=M, block_length=N)
# Consider passing num_parallel_calls or using tf.contrib.data.map_and_batch for performance
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)

编辑:

如果还需要改组,则可以在交错步骤中添加它:

import tensorflow as tf

N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
SHUFFLE_BUFFER_SIZE = 1000

dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
    lambda record_file: tf.data.TFRecordDataset(record_file).shuffle(SHUFFLE_BUFFER_SIZE),
    cycle_length=M, block_length=N)
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)

注意:如果没有更多剩余元素,则interleavebatch都将产生“部分”输出(请参阅文档)。因此,如果每个批次都具有相同的形状和结构对您来说很重要,则必须格外小心。至于批处理,您可以使用tf.contrib.data.batch_and_drop_remainder,但据我所知,没有类似的替代方法,因此您必须确保所有文件都具有相同数量的示例,或者只是添加repeat进行交错转换。

编辑2:

我得到了我想像的东西的概念证明:

import tensorflow as tf

NUM_EXAMPLES = 12
NUM_CLASSES = 9
records = [[str(i)] * NUM_EXAMPLES for i in range(NUM_CLASSES)]
M = 3
N = 4

dataset = tf.data.Dataset.from_tensor_slices(records)
dataset = dataset.interleave(tf.data.Dataset.from_tensor_slices,
                             cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
    lambda data: tf.data.Dataset.from_tensor_slices(
        tf.split(tf.random_shuffle(
            tf.reshape(data, (NUM_CLASSES, N))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N,)))
batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            b = sess.run(batch)
            print(b''.join(b).decode())
        except tf.errors.OutOfRangeError: break

输出:

888866663333
555544447777
222200001111
222288887777
666655553333
000044441111
888822225555
666600004444
777733331111

与记录文件等效的是这样的(假设记录是一维向量):

import tensorflow as tf

NUM_CLASSES = 9
record_files = ['class{}.tfrecord'.format(i) for i in range(NUM_CLASSES)]
M = 3
N = 4
SHUFFLE_BUFFER_SIZE = 1000

dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
    lambda file_name: tf.data.TFRecordDataset(file_name).shuffle(SHUFFLE_BUFFER_SIZE),
    cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
    lambda data: tf.data.Dataset.from_tensor_slices(
        tf.split(tf.random_shuffle(
            tf.reshape(data, (NUM_CLASSES, N, -1))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N, -1)))

这是通过每次阅读每个类的N元素,并对结果块进行改组和拆分来实现的。它假设类的数目可以被M整除,并且所有文件的记录数都相同。

相关问题