在写入和读取tfrecord文件之间得到不匹配的结果

时间:2018-10-31 10:26:31

标签: api tensorflow dataset pipeline

在这里,我使用此函数来写入多个tfrecord文件:

writer = tf.python_io.TFRecordWriter(save)
for pth, lb in tqdm(zip(piece_p, piece_l)):
    # mind that the path should be read into image data first
    # to convert the byteslist data format into raw bytes
    data = Image.open(pth)
    if resize is not None:
        data.thumbnail(resize, Image.ANTIALIAS)
    features = tf.train.Features(feature={
        'image': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[data.tobytes()])),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[lb]))
    })
    example = tf.train.Example(features=features)

    # serialize the constructed data format before writing step
    writer.write(example.SerializeToString())
    sys.stdout.flush()
writer.close()

并使用以下代码解析二进制文件:

def parse_fn(serialized):
    features = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    parse_exp = tf.parse_single_example(serialized=serialized,
                                    features=features)
    labels = parse_exp['label']
    data = parse_exp['image']
    data = tf.decode_raw(data, tf.uint8)
    data = tf.cast(data, tf.float32)
    del parse_exp
    return data, labels

dataset = tf.data.Dataset.list_files(data_list, shuffle=True)
dataset = dataset.interleave(tf.data.TFRecordDataset,
                             cycle_length=file_num)
# dataset = tf.data.TFRecordDataset(data_list[0])

dataset = dataset.map(parse_fn, num_parallel_calls=4)

但是为什么标签和数据的数量总是不匹配...? 每次添加以下代码进行多次处理时……

dataset = dataset.batch(12)
dataset = dataset.repeat(1)
iterator = dataset.make_initializable_iterator()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer)
data, labels = iterator.get_next()

,标签数量始终为数据的一半。我的参数设置有问题吗?我很确定我的保存部分和阅读部分没有错...但是将它们组合在一起会出现一些问题。

0 个答案:

没有答案