如何在无法执行Initializable_iterator的急切执行模式下动态提供tf.data.Dataset?

时间:2019-07-08 16:41:34

标签: python tensorflow

当我们需要逐个样本地馈送数据时,以动态方式通过数据集管道馈送数据的新方法(急于执行)是什么?

我有一个tf.data.Dataset,它执行一些预处理步骤并从生成器中读取数据,并在训练过程中从大型数据集中提取数据。

假设数据集表示为:

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
ds = ds.map(tf.square).shuffle(2).batch(2)
iterator = tf.data.make_one_shot_iterator(ds)

训练后,我想产生各种可视化效果,这些效果要求我一次通过网络输入一个样本进行推理。现在,我已经有了这个数据集预处理管道,我需要将其馈入原始样本,以针对网络输入进行适当的大小和形状调整。

这似乎是可初始化迭代器的用例:

placeholder = tf.placeholder(tf.float32, shape=None)
ds = tf.data.Dataset.from_tensor_slices(placeholder)
ds = ds.map(tf.square).shuffle(2).batch(2)
iterator = tf.data.make_initializable_iterator(ds) 
# now re-initialize for each sample
  

请记住,此示例中的map操作代表了很长的预处理操作序列,对于要输入的每个新数据样本都无法复制。

这不适用于急于执行的操作,您不能使用占位符。文档示例似乎都假设静态输入,例如此处的第一个示例。

我想到的唯一方法是使用队列和tf.data.Dataset.from_generator(...),该队列从我推送到的队列中读取数据,然后对数据进行预测。但这感觉既棘手,而且似乎容易陷入僵局,我尚未解决。

TF 1.14.0

1 个答案:

答案 0 :(得分:0)

我刚刚意识到这个问题的答案很简单:

  

只需创建一个新的数据集!

在非紧急模式下,下面的代码会降低性能,因为每个数据集操作都已添加到图形中而从未发布,而在非紧急模式下,我们具有可初始化的迭代器来解决该问题。

但是,在急切的执行模式下,像这样的张量流操作是短暂的,添加的迭代器不会被添加到全局图中,它们只会在不再引用时创建并死亡。赢取TF2.0之一!

下面的代码(可复制/粘贴可运行)演示:

import tensorflow as tf
import numpy as np
import time


tf.enable_eager_execution()

inp = np.ones(shape=5000, dtype=np.float32)

t = time.time()
while True:
    ds = tf.data.Dataset.from_tensors(inp).batch(1)
    val = next(iter(ds))
    assert np.all(np.squeeze(val, axis=0) == inp)
    print('Processing time {:.2f}'.format(time.time() - t))
    t = time.time()

该问题的动机紧随其后的是1.14,其中在Keras下以图形模式创建多个数据集操作构成内存泄漏。

https://github.com/tensorflow/tensorflow/issues/30448