如何将SequenceExample对象的读取从tf.python_io.tf_record_iterator转换为tf.data.TFRecordDataset

时间:2019-08-20 19:49:20

标签: python tensorflow tfrecord

所以我有一个TFRecords格式的数据集,并且我试图将使用tf.python_io.tf_record_iterator读取的数据集转换为tf.data.TFRecordDataset。

除了不推荐使用tf.python_io.tf_record_iterator之外,这样做的主要原因是我希望能够使用tf.data.Dataset对象。

在TFRecords文件中,每个条目都是一个SequenceExample,特别是tensorflow.core.example.example_pb2.SequenceExample。

当前我正在通过此函数读取每个SequenceExample:

def read_records(record_path):
    records = []
    record_iterator = tf.python_io.tf_record_iterator(path=record_path)

    for string_record in record_iterator:
        example = tf.train.SequenceExample()
        example.ParseFromString(string_record)
        records.append(example)

    return records

打印记录会给我这种结构(由于长度而被截断):

context {
  feature {
    key: "framecount"
    value {
      int64_list {
        value: 10
      }
    }
  }
  feature {
    key: "label"
    value {
      int64_list {
        value: 1
      }
    }
  }
}
feature_lists {
  feature_list {
    key: "positions"
    value {
      feature {
        bytes_list {
          value: "\221\2206?\200dL?\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000"
        }
      }
    }
  }
}

现在,如果我尝试使用tf.data.TFRecordDataset执行此操作,则我的功能是:

def reader(file_path):
    dataset = tf.data.TFRecordDataset(file_path)
    for record in dataset:
        tf.io.parse_sequence_example(record)

    return dataset

给我一​​个值错误,表明我没有提供值或上下文功能。这是正确的,因为记录具有所述值。 (尽管看起来TFRecordDataset的数据输出与旧的记录迭代器不同,但我还尝试通过训练新的SequenceExample来遵循第一个功能的相同流程。)

鉴于此,我将如何正确生成我的sequenceExample?尽管从技术上讲我可以给它提供参数,但这似乎很不直观,特别是因为数据已经在记录中。

或者,(尽管这更像是一个创可贴修复),如何将第一个函数中的列表转换为张量流数据集对象?

1 个答案:

答案 0 :(得分:0)

好的,所以这有点棘手...

似乎tf.python_io.tf_record_iterator以SequenceExample.FromString()可以解析的直接二进制格式输出数据。另一方面,TFRecordDataset以直接张量格式返回数据。

由于我的意图是能够通过Dataset对象的内置生成器功能将数据点传递给模型,因此可以通过包装TFRecordDataset的输出来解决它。具体来说,我可以使用SequenceExample.FromString(datapoint.numpy())获得所需的输出。

这有点罗word,所以我的解决方案功能如下:

def reader(file_path):
    dataset = tf.data.TFRecordDataset(file_path)
    for record in dataset:
        record = tf.train.SequenceExample.FromString(record.numpy())
        yield record

这是我问题中第二个函数的直接修改