将.tfrecords文件拆分为多个.tfrecords文件

时间:2019-02-04 15:25:49

标签: python tensorflow tensorflow-datasets tfrecord

是否有任何方法可以将.tfrecords文件直接拆分为多个.tfrecords文件,而无需回写每个数据集示例?

6 个答案:

答案 0 :(得分:4)

您可以使用如下功能:

__get()

例如,要将文件import tensorflow as tf def split_tfrecord(tfrecord_path, split_size): with tf.Graph().as_default(), tf.Session() as sess: ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size) batch = ds.make_one_shot_iterator().get_next() part_num = 0 while True: try: records = sess.run(batch) part_path = tfrecord_path + '.{:03d}'.format(part_num) with tf.python_io.TFRecordWriter(part_path) as writer: for record in records: writer.write(record) part_num += 1 except tf.errors.OutOfRangeError: break 分成100条记录的一部分,您可以这样做:

my_records.tfrecord

这将创建多个较小的记录文件split_tfrecord(my_records.tfrecord, 100) my_records.tfrecord.000等。

答案 1 :(得分:2)

使用.batch()而不是.shard(),以避免多次遍历数据集

更高效的方法(与使用tf.data.Dataset.shard()相比)将使用批处理:

import tensorflow as tf

ITEMS_PER_FILE = 100 # Assuming we are saving 100 items per .tfrecord file


raw_dataset = tf.data.TFRecordDataset('in.tfrecord')

batch_idx = 0
for batch in raw_dataset.batch(ITEMS_PER_FILE):

    # Converting `batch` back into a `Dataset`, assuming batch is a `tuple` of `tensors`
    batch_ds = tf.data.Dataset.from_tensor_slices(tuple([*batch]))
    filename = f'out.tfrecord.{batch_idx:03d}'

    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(batch_ds)

    batch_idx += 1

答案 2 :(得分:2)

TensorFlow 2.x的非常有效的方法

正如@yongjieyongjie所提到的,您应该使用.batch()而不是.shard(),以避免根据需要在数据集上进行更多的迭代。 但是,如果您有一个非常大的数据集,对于内存来说太大了,它将失败(但不会出错),只给您几个文件和原始数据集的一部分。

首先,您应该对数据集进行批处理,并将每个文件要具有的记录数量用作批处理大小(我假设您的数据集已经是序列化格式,否则请参见here)。

dataset = dataset.batch(ITEMS_PER_FILE)

接下来要做的是使用生成器以避免内存不足。

def write_generator():
    i = 0
    iterator = iter(dataset)
    optional = iterator.get_next_as_optional()
    while optional.has_value().numpy():
        ds = optional.get_value()
        optional = iterator.get_next_as_optional()
        batch_ds = tf.data.Dataset.from_tensor_slices(ds)
        writer = tf.data.experimental.TFRecordWriter(save_to + "\\" + name + "-" + str(i) + ".tfrecord", compression_type='GZIP')#compression_type='GZIP'
        i += 1
        yield batch_ds, writer, i
    return

现在只需在常规for循环中使用生成器

for data, wri, i in write_generator():
    start_time = time.time()
    wri.write(data)
    print("Time needed: ", time.time() - start_time, "s", "\t", NAME_OF_FILES + "-" + str(i) + ".tfrecord")

只要一个文件适合原始存储,就可以正常工作。

答案 3 :(得分:1)

在tensorflow 2.0.0中,这将起作用:

import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("input_file.tfrecord")

shards = 10

for i in range(shards):
    writer = tf.data.experimental.TFRecordWriter(f"output_file-part-{i}.tfrecord")
    writer.write(raw_dataset.shard(shards, i))

答案 4 :(得分:0)

不均匀的分割

如果您想均匀地分成大小相等的文件,大多数其他答案都有效。这将适用于不均匀的分割:

# `splits` is a list of the number of records you want in each output file
def split_files(filename: str, splits: List[int]) -> None:
    dataset: tf.data.Dataset = tf.data.TFRecordDataset(filename)
    rec_counter: int = 0

    # An extra iteration over the data to get the size
    total_records: int = len([r for r in dataset])
    print(f"Found {total_records} records in source file.")

    if sum(splits) != total_records:
        raise ValueError(f"Sum of splits {sum(splits)} does not equal "
                         f"total number of records {total_records}")

    rec_iter:Iterator = iter(dataset)
    split: int
    for split_idx, split in enumerate(splits):
        outfile: str = filename + f".{split_idx}-{split}"
        with tf.io.TFRecordWriter(outfile) as writer:
            for out_idx in range(split):
                rec: tf.Tensor = next(rec_iter, None)
                rec_counter +=1
                writer.write(rec.numpy())
        print(f"Finished writing {split} records to file {split_idx}")

虽然我认为技术上 OP 询问了 without writing back each Dataset example(这就是这样做的),但这至少是在没有反序列化每个示例的情况下进行的。

对于非常大的文件来说有点慢。可能有一种方法可以修改其他一些基于批处理的答案,以便使用批处理输入读取但仍然写入不均匀的拆分,但我还没有尝试过。

答案 5 :(得分:0)

分成 N 个分割 (在 tensorflow 1.13.1 中测试)

import os
import hashlib
import tensorflow as tf
from tqdm import tqdm


def split_tfrecord(tfrecord_path, n_splits):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    outfiles=[]
    for n_split in range(n_splits):
        output_tfrecord_dir = f"{os.path.splitext(tfrecord_path)[0]}"
        if not os.path.exists(output_tfrecord_dir):
            os.makedirs(output_tfrecord_dir)
        output_tfrecord_path=os.path.join(output_tfrecord_dir, f"{n_split:03d}.tfrecord")
        out_f = tf.io.TFRecordWriter(output_tfrecord_path)
        outfiles.append(out_f)

    for record in tqdm(dataset):
        sample = tf.train.Example()
        record = record.numpy()
        sample.ParseFromString(record)

        idx = int(hashlib.sha1(record).hexdigest(),16) % n_splits
        outfiles[idx].write(example.SerializeToString())

    for file in outfiles:
        file.close()