如何在TF 2.0数据集映射步骤中从张量获取和使用值?

时间:2019-03-15 14:48:55

标签: tensorflow tensorflow-datasets

首先,我使用TensorFlow Alpha 2.0。

我有我正在读取的TFRecords文件,每个文件都包含一个简短的视频剪辑,每个帧都编码为jpeg字节字符串以节省空间:

{
  'numframes': tf.io.FixedLenFeature([], tf.int64),
  'frames': tf.io.VarLenFeature(tf.string)
}

我在tf.data.Dataset管道中有一个映射步骤,可以成功解析每个示例:

def parse_tfrecord(p):
    return tf.io.parse_single_example(p, example_schema)

我的下一步是从numframes中读取帧数,并对frames.values[i]中的每个帧运行tf.io.decode_jpeg函数,其中i来自{{1} }:

range(numframes)

我的数据集管道的完整性:

def parse_jpegs(p):
    numframes = p['numframes']
    return tf.map_fn(tf.io.decode_jpeg, [p['frames'].values[i] for i in range(numframes)])

如果我排除了def dataset(): dataset = tf.data.Dataset.list_files("*.tfrecord") dataset = tf.data.TFRecordDataset(dataset) dataset = dataset.shuffle(1000).repeat() dataset = dataset.map(parse_tfrecord) dataset = dataset.map(parse_jpegs) return dataset 行,那么一切正常,向我显示类似dataset.map(parse_jpegs)

(请注意,numframe张量包含numpy值25。我可以使用tensor.numpy()方法在我的数据集管道之外获取该值)

尽管在那个map函数中,我无法调用.numpy()从张量中获取值,并且在打印张量本身时尚未对其进行评估或其他操作,因为还没有显示任何值。 / p>

解析数据集管道中所有这些帧的最佳方法是什么?

编辑:尝试获取numframe时,我收到的错误消息是parse_jpegs中的{'frames': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f394c285518>, 'numframes': <tf.Tensor: id=2937, shape=(), dtype=int64, numpy=25>}。这对我来说很有意义,为什么不能将张量解释为一个int,但是如何从该张量中获取值来设置范围呢?

我遇到的问题归结为每个“帧”对象都有不同数量的帧。如果我可以将TypeError: 'Tensor' object cannot be interpreted as an integer应用于该列表中的每个帧,而无需分别记录帧数,那么我可以,但是我这里有“ numframes”,因此我知道“框架”列表。

编辑:我将向其他可能认为有帮助的人提出这个问题,但最终我只返回了原始字节串,并在数据集API之外的单独的生成器函数中执行了decode_jpeg 。即使这样可能更慢,这种方式也更容易。

1 个答案:

答案 0 :(得分:0)

在我的特定情况下,我最终发现map_fn试图将我的输入张量转换为相同类型的输出张量。在这种情况下,tf.io.decode_jpeg接收一个字符串(字节)并输出一个uint8数组,这会引起问题。 tf.map_fn(... output_type=tf.uint8)的另一个论点似乎已经为我解决了!自从问了问题以来我继续修补问题以来,也许并不完全像我写的那样,但是我现在使它起作用了。