如何有效地从Tensorflow中的TFRecords加入数据

时间:2019-05-23 18:40:50

标签: python tensorflow tfrecord tpu

TFRecord上训练TensorFlow模型时,我需要有效地加入少量数据。如何使用已解析的TFRecord中的信息进行查找?

更多详细信息:

我正在使用TFRecords在大型数据集上训练卷积网络。每个TFRecord都包含原始图像以及目标标签,以及有关该图像的一些元数据。培训的一部分是,我需要使用特定于图像分组的meanstd来标准化图像。过去,我已经将meanstd硬编码到TFRecord中。然后像在我的parse_example中那样使用它,它被用来映射在我的Dataset中的input_fn上,就像这样:

def parse_example(..):
    # ...
    parsed = tf.parse_single_example(value, keys_to_features)
    image_raw = tf.decode_raw(parsed['image/raw'], tf.uint16)
    image = tf.reshape(image_raw, image_shape)
    image.set_shape(image_shape)

    # pull hardcoded pixels mean and std from the parsed TFExample
    mean = parsed['mean']
    std = parsed['std']

    image = (tf.cast(image, tf.float32) - mean) / std

    # ...

    return image, label

尽管上述方法可以缩短培训时间,但它的局限性在于我经常想更改我使用的meanstd。与其将meanstd写入TFRecord中,我不希望在训练时查找适当的摘要统计信息。这意味着当我训练时,我有一个小的python字典,可以使用有关从TFRecord解析的图像的信息来查找适当的摘要统计信息。我遇到的问题是我似乎无法在我的tensorflow图中使用此python字典。如果我尝试直接进行查找,则无法使用,因为我有张量对象而不是实际的图元。这是合理的input_fn正在做符号操作来构造TensorFlow的计算图(对吗?)。我该如何解决?

我尝试过的一件事是从字典中创建查找表,如下所示:

def create_channel_hashtable(keys, values, default_val=-1):
    initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
    return tf.contrib.lookup.HashTable(initializer, default_val)

可以创建哈希表,并在parse_example函数中使用哈希表进行查找。所有这些都“有效”,但是却过分地减慢了训练速度。可能值得注意的是,此培训是在TPU上进行的。使用从TFRecord开始使用值的原始方法,训练很快,并且不受IO的限制,但是当使用哈希查找时,这种情况会改变。建议如何处理这些情况?重新打包TFRecord是可行的,但当要查找的数据很小且可以提高效率时,这似乎很愚蠢。

1 个答案:

答案 0 :(得分:0)

这种问题涵盖了本主题:

How can I merge multiple tfrecords file into one file?

似乎您将TFRecords保存到文件中,然后使用TFRecordDataset将它们全部拉到一个数据集中。我所链接的上述问题的答案中给出的代码是:

dataset = tf.data.TFRecordDataset(filenames_to_read,
    compression_type=None,    # or 'GZIP', 'ZLIB' if compress you data.
    buffer_size=10240,        # any buffer size you want or 0 means no buffering
    num_parallel_reads=os.cpu_count()  # or 0 means sequentially reading
)