正如标题中所述,我需要具有特殊结构的批次:
1111
5555
2222
每个数字代表特征向量。因此,每个类N=4
({1,2,5}
)都有M=3
个向量,批处理大小为NxM=12
。
要完成此任务,我使用Tensorflow Dataset API和tfrecords:
M
个随机迭代器,并从每个迭代器中生成N
个特征向量我担心的是,我有数百个类(在功能中可能是数千个),并且从内存和性能的角度来看,为每个类存储迭代器看起来并不好。
有更好的方法吗?
答案 0 :(得分:2)
如果您具有按类排序的文件列表,则可以交错数据集:
import tensorflow as tf
N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
dataset = tf.data.Dataset.from_tensor_slices(record_files)
# Consider tf.contrib.data.parallel_interleave for parallelization
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=M, block_length=N)
# Consider passing num_parallel_calls or using tf.contrib.data.map_and_batch for performance
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)
编辑:
如果还需要改组,则可以在交错步骤中添加它:
import tensorflow as tf
N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
SHUFFLE_BUFFER_SIZE = 1000
dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
lambda record_file: tf.data.TFRecordDataset(record_file).shuffle(SHUFFLE_BUFFER_SIZE),
cycle_length=M, block_length=N)
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)
注意:如果没有更多剩余元素,则interleave
和batch
都将产生“部分”输出(请参阅文档)。因此,如果每个批次都具有相同的形状和结构对您来说很重要,则必须格外小心。至于批处理,您可以使用tf.contrib.data.batch_and_drop_remainder
,但据我所知,没有类似的替代方法,因此您必须确保所有文件都具有相同数量的示例,或者只是添加repeat
进行交错转换。
编辑2:
我得到了我想像的东西的概念证明:
import tensorflow as tf
NUM_EXAMPLES = 12
NUM_CLASSES = 9
records = [[str(i)] * NUM_EXAMPLES for i in range(NUM_CLASSES)]
M = 3
N = 4
dataset = tf.data.Dataset.from_tensor_slices(records)
dataset = dataset.interleave(tf.data.Dataset.from_tensor_slices,
cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
lambda data: tf.data.Dataset.from_tensor_slices(
tf.split(tf.random_shuffle(
tf.reshape(data, (NUM_CLASSES, N))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N,)))
batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
b = sess.run(batch)
print(b''.join(b).decode())
except tf.errors.OutOfRangeError: break
输出:
888866663333
555544447777
222200001111
222288887777
666655553333
000044441111
888822225555
666600004444
777733331111
与记录文件等效的是这样的(假设记录是一维向量):
import tensorflow as tf
NUM_CLASSES = 9
record_files = ['class{}.tfrecord'.format(i) for i in range(NUM_CLASSES)]
M = 3
N = 4
SHUFFLE_BUFFER_SIZE = 1000
dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
lambda file_name: tf.data.TFRecordDataset(file_name).shuffle(SHUFFLE_BUFFER_SIZE),
cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
lambda data: tf.data.Dataset.from_tensor_slices(
tf.split(tf.random_shuffle(
tf.reshape(data, (NUM_CLASSES, N, -1))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N, -1)))
这是通过每次阅读每个类的N
元素,并对结果块进行改组和拆分来实现的。它假设类的数目可以被M
整除,并且所有文件的记录数都相同。