您如何在tf.data和TFRecords中使用参差不齐的张量?

时间:2019-03-09 23:28:39

标签: python tensorflow tensorflow-datasets

Tensorflow最近发布了参差不齐的张量:https://www.tensorflow.org/guide/ragged_tensors

但是没有有关如何将不完整的数据另存为TFRecord以及如何使用数据api还原它的文档。

2 个答案:

答案 0 :(得分:1)

很遗憾,没有 RaggedFeature 或同等功能。最好的选择是转换为稀疏(通过to_sparse())并将数据编码为 SparseFeature 。解码后,您可以通过from_sparse()构建器将其转换回参差不齐。

答案 1 :(得分:0)

一个参差不齐的张量需要两个数组:valuessomething定义应如何将values拆分为行(例如row_splitsrow_lengths,.. (请参见docs)。我的想法是将这两个数组作为两个特征存储在tf.Example 中,并在加载文件时创建参差不齐的张量。

例如:

import tensorflow as tf

def serialize_example(vals, lens):
  vals = tf.train.Feature(int64_list=tf.train.Int64List(value=vals))
  lens = tf.train.Feature(int64_list=tf.train.Int64List(value=lens))
  example = tf.train.Example(features=tf.train.Features(
      feature={'vals': vals, 'lens': lens})
  )
  return example.SerializeToString()

def parse_example(raw_example):
  example = tf.io.parse_single_example(raw_example, {
      'vals':tf.io.VarLenFeature(dtype=tf.int64),
      'lens':tf.io.VarLenFeature(dtype=tf.int64)
  })
  return tf.RaggedTensor.from_row_lengths(
      example['vals'].values, row_lengths=example['lens'].values
  )

ex1 = serialize_example([1,2,3,4,5,6,7,8,9,10], [3,2,5])
print(parse_example(ex1))  # <tf.RaggedTensor [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]>
ex2 = serialize_example([1,2,3,4,5,6,7,8], [2,2,4])
print(parse_example(ex2))  # <tf.RaggedTensor [[1, 2], [3, 4], [5, 6, 7, 8]]>

从TFRecord文件创建数据集时,可以通过将parse_example传递给Dataset.map()函数来应用{}作为转换。