Tensorflow 1.10 TFRecordDataset-恢复TFRecords

时间:2018-08-28 19:17:53

标签: python tensorflow python-3.6 tensorflow-datasets tensorflow-estimator

注意:

  1. 此问题扩展到先前的question of mine。在该问题中,我询问了将一些伪数据存储为ExampleSequenceExample的最佳方法,以试图了解哪种数据更类似于提供的伪数据。我提供了ExampleSequenceExample结构的明确表述,并在回答中提供了一种编程的方式。

  2. 因为这仍然是很多代码,所以我提供了一个Colab(由Google托管的交互式jupyter笔记本)文件,您可以自己尝试使用该代码来提供帮助。所有必要的代码均已存在,并对其进行了慷慨的注释。

我正在尝试学习如何将我的数据转换为TF记录,因为声称的利益对于我的数据是值得的。但是,文档还有很多需要改进的地方,而试图深入学习的教程/博客(我见过)实际上只是接触表面或重新整理了现有的稀疏文档。

对于我的previous question中以及此处所考虑的演示数据,我编写了一个不错的类,其内容如下:

  • 具有n个通道的序列(在此示例中,它是基于整数的,具有固定长度的n个通道)
  • 带有软标签的类概率(在此示例中,有n个类且基于浮点数)
  • 一些元数据(在此示例中为一个字符串和两个浮点数)

并可以采用6种形式之一对数据进行编码:

  1. 示例,其中顺序通道/类以数字类型(在这种情况下为int64)分开,并附加了元数据
  2. 示例,其中顺序通道/类作为字节字符串(通过numpy.ndarray.tostring())分开,并附加了元数据
  3. 示例,将序列/类作为字节字符串转储,并附加了元数据

  4. SequenceExample,序列通道/类以数字类型分开,元数据作为上下文

  5. SequenceExample,其中序列通道作为字节字符串分开,而元数据作为上下文
  6. SequenceExample,序列和类作为字节字符串转储,而元数据作为上下文转储

这很好。

Colab中,我展示了如何将伪数据全部写入同一文件以及单独的文件中。

我的问题是如何恢复这些数据?

我在链接文件中尝试了4次尝试。

为什么TFReader与TFWriter处于不同的子软件包下?

1 个答案:

答案 0 :(得分:5)

已通过更新功能以包括形状信息并记住SequenceExample 未命名 FeatureLists来解决。

context_features = {
    'Name' : tf.FixedLenFeature([], dtype=tf.string),
    'Val_1': tf.FixedLenFeature([], dtype=tf.float32),
    'Val_2': tf.FixedLenFeature([], dtype=tf.float32)
}

sequence_features = {
    'sequence': tf.FixedLenSequenceFeature((3,), dtype=tf.int64),
    'pclasses'  : tf.FixedLenSequenceFeature((3,), dtype=tf.float32),
}

def parse(record):
  parsed = tf.parse_single_sequence_example(
        record,
        context_features=context_features,
        sequence_features=sequence_features
  )
  return parsed


filenames = [os.path.join(os.getcwd(),f"dummy_sequences_{i}.tfrecords") for i in range(3)]
dataset = tf.data.TFRecordDataset(filenames).map(lambda r: parse(r))

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(dataset)

for _ in range(2):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(3):
    ne = sess.run(next_element)
    print(ne)