使用tf.io.parse_example从TFRecords中读取示例时,如何设置批处理大小?

时间:2019-07-02 01:46:12

标签: python tensorflow

我对tf.io.parse_example有疑问。

我阅读了有关从TFRecords导入数据的Tensorflow指南 here.

本指南提供了以下从TFRecords中读取和解析示例的方法:

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

我想使用tf.io.parse_example而不是tf.io.parse_single_example

因此,我将代码修改如下:

def _parse_function(batch_of_example_protos):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_example(batch_of_example_protos, features)
  return parsed_features

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
batch_size = 32
dataset = dataset.batch(batch_size).map(_parse_function)

我的问题:当我们在调用map(_parse_function)之前使用批处理时, 我们如何确定batch_size参数的适当值?

如果我们要优化性能,如何确定batch_size的值?

我已经尝试过以不同的批次大小测试性能,但是我想知道是否存在一种不太经验的评估方法。

谢谢!

其他详细信息

0 个答案:

没有答案