将列表写入和读取到TFRecord示例

时间:2019-06-02 20:50:09

标签: tensorflow tfrecord

我想将一个整数列表(或任何多维numpy矩阵)写入一个TFRecords示例。对于单个值或多个值的列表,我都可以无错误地创建TFRecord文件。我也知道如何从TFRecord文件中读取单个值,如下面从各种来源编译的代码示例所示。

# Making an example TFRecord

my_example = tf.train.Example(features=tf.train.Features(feature={
    'my_ints': tf.train.Feature(int64_list=tf.train.Int64List(value=[5]))
}))

my_example_str = my_example.SerializeToString()
with tf.python_io.TFRecordWriter('my_example.tfrecords') as writer:
    writer.write(my_example_str)

# Reading it back via a Dataset

featuresDict = {'my_ints': tf.FixedLenFeature([], dtype=tf.int64)}

def parse_tfrecord(example):
    features = tf.parse_single_example(example, featuresDict)
    return features

Dataset = tf.data.TFRecordDataset('my_example.tfrecords')
Dataset = Dataset.map(parse_tfrecord)
iterator = Dataset.make_one_shot_iterator()
with tf.Session() as sess:
   print(sess.run(iterator.get_next()))

但是如何从一个示例中读回值列表(例如[5,6])? featuresDict将功能定义为int64类型,当其中包含多个值并且出现以下错误时,该功能将失败:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: my_ints.  Can't parse serialized Example.

1 个答案:

答案 0 :(得分:0)

您可以通过使用tf.train.SequenceExample来实现。我已经编辑了您的代码以返回1D和2D数据。首先,创建放置在tf.train.FeatureList中的功能列表。我们将2D数据转换为字节。

vals = [5, 5]
vals_2d = [np.zeros((5,5), dtype=np.uint8), np.ones((5,5), dtype=np.uint8)]

features = [tf.train.Feature(int64_list=tf.train.Int64List(value=[val])) for val in vals]
features_2d = [tf.train.Feature(bytes_list=tf.train.BytesList(value=[val.tostring()])) for val in vals_2d]
featureList = tf.train.FeatureList(feature=features)
featureList_2d = tf.train.FeatureList(feature=features_2d)

为了获得正确的2D特征形状,我们需要提供上下文(非顺序数据),这是通过上下文字典完成的。

context_dict = {'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[vals_2d[0].shape[0]])), 
            'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[vals_2d[0].shape[1]])),
           'length': tf.train.Feature(int64_list=tf.train.Int64List(value=[len(vals_2d)]))}

然后将每个FeatureList放入tf.train.FeatureLists字典中。最后,将其与上下文字典一起放置在tf.train.SequenceExample中

my_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list={'1D':featureList,
                                                                                   '2D': featureList_2d}),
                                 context = tf.train.Features(feature=context_dict))
my_example_str = my_example.SerializeToString()
with tf.python_io.TFRecordWriter('my_example.tfrecords') as writer:
    writer.write(my_example_str)

要将其读回张量流,您需要使用tf.FixedLenSequenceFeature作为顺序数据,并使用tf.FixedLenFeature作为上下文数据。我们将字节转换回整数,并解析上下文数据,以恢复正确的形状。

# Reading it back via a Dataset
featuresDict = {'1D': tf.FixedLenSequenceFeature([], dtype=tf.int64),
           '2D': tf.FixedLenSequenceFeature([], dtype=tf.string)}
contextDict = {'height': tf.FixedLenFeature([], dtype=tf.int64),
          'width': tf.FixedLenFeature([], dtype=tf.int64),
          'length':tf.FixedLenFeature([], dtype=tf.int64)}

def parse_tfrecord(example):
    context, features = tf.parse_single_sequence_example(example, sequence_features=featuresDict,                                                       context_features=contextDict)

    height = context['height']
    width = context['width']
    seq_length = context['length']
    vals = features['1D']
    vals_2d = tf.decode_raw(features['2D'], tf.uint8)
    vals_2d = tf.reshape(vals_2d, [seq_length, height, width])
    return vals, vals_2d

Dataset = tf.data.TFRecordDataset('my_example.tfrecords')
Dataset = Dataset.map(parse_tfrecord)
iterator = Dataset.make_one_shot_iterator()
with tf.Session() as sess:
    print(sess.run(iterator.get_next()))

这将输出[5,5]和2D numpy数组的序列。这篇博客文章更深入地介绍了如何使用tfrecords https://dmolony3.github.io/Working%20with%20image%20sequences.html

定义序列