我有一个填充了标记数据的.tfrecords
文件。我想将X%用于培训,使用(1-X)%进行评估/测试。显然不应该有任何重叠。这样做的最佳方式是什么?
以下是我阅读tfrecords
的一小段代码。有没有办法让shuffle_batch
将数据分成训练和评估数据?我错了吗?
reader = tf.TFRecordReader()
files = tf.train.string_input_producer([TFRECORDS_FILE], num_epochs=num_epochs)
read_name, serialized_examples = reader.read(files)
features = tf.parse_single_example(
serialized = serialized_examples,
features={
'image': tf.FixedLenFeature([], tf.string),
'value': tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['image'], tf.uint8)
value = tf.decode_raw(features['value'], tf.uint8)
image, value = tf.train.shuffle_batch([image, value],
enqueue_many = False,
batch_size = 4,
capacity = 30,
num_threads = 3,
min_after_dequeue = 10)
答案 0 :(得分:1)
尽管这个问题是一年多以前问的,但最近我也遇到了类似的问题。
我对输入哈希使用tf.data.Dataset和过滤器。这是一个示例:
dataset = tf.data.TFRecordDataset(files)
if is_evaluation:
dataset = dataset.filter(
lambda r: tf.string_to_hash_bucket_fast(r, 10) == 0)
else:
dataset = dataset.filter(
lambda r: tf.string_to_hash_bucket_fast(r, 10) != 0)
dataset = dataset.map(tf.parse_single_example)
return dataset
到目前为止,我注意到的缺点之一是,每次评估都可能需要遍历10倍的数据才能收集足够的数据。为避免这种情况,您可能需要在数据预处理时分离数据。