如何使用自定义生成器使tf.data.Dataset.from_generator产生批处理

时间:2018-08-01 09:48:01

标签: python tensorflow tensorflow-datasets

我想使用.html API。我期望的工作流程如下所示:

  • 输入图像是带有tf.data

  • 的5D张量
  • 第一层是3D卷积

我使用(batch_size, width, height, channels, frames)函数创建一个迭代器。后来我做了一个可初始化的迭代器。

我的代码如下:

tf.data.from_generator

我希望def custom_gen(): img = np.random.normal((width, height, channels, frames)) yield(img, img) # I train an autoencoder, so the x == y` dataset = tf.data.Dataset.batch(batch_size).from_generator(custom_generator) iter = dataset.make_initializable_iterator() sess = tf.Session() sess.run(iter.get_next()) 为我提供具有批处理大小的5D张量。但是,我什至试图用自己的iter.get_next()来产生批处理大小,但它不起作用。当我想使用输入形状为custom_generator的占位符初始化数据集时,遇到一个错误。

1 个答案:

答案 0 :(得分:2)

该示例中的Dataset构造过程格式不正确。应当按照Importing Data官方指南中确定的顺序进行操作:

  1. 应该调用基本的数据集创建函数或静态方法来建立原始的数据源(例如,静态方法from_slice_tensorsfrom_generatorlist_files ,...)。
  2. 这时,可以通过链接适配器方法(例如batch)来应用转换

因此:

dataset = tf.data.Dataset.from_generator(custom_generator).batch(batch_size)