如何在Tensorflow Data API中正确使用dataset.map

时间:2018-11-19 13:40:41

标签: python tensorflow tensorflow-datasets

我正在学习如何使用Tensorflow Data API,并努力了解映射的工作原理。对于上下文,我想加载图像数据集并将其发送到神经网络。

下面的MWE可以做到这一点(大小为10的伪数据集,read_image函数映射到该数据集)。

import tensorflow as tf
import numpy as np


def read_image(filename, label):
    return np.random.rand(8, 8, 1), label # simulate data load (generate random data)

# generate fake dataset of filenames (of size 10)
filenames = tf.constant(np.asarray(["file" + str(i) for i in range(10)]))
labels = tf.constant(np.asarray([2*i for i in range(10)]))

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(read_image)

dataset = dataset.repeat().batch(2)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)


X, y = iterator.get_next()
train_init_op = iterator.make_initializer(dataset)

with tf.Session() as session:
    tf.global_variables_initializer().run()
    session.run(train_init_op)
    for _ in range(10):
        print(session.run([X]))

运行此代码(应该什么也不做,只打印read_image生成的值)时,它最终总是具有相同的数据:read_image仅被调用一次。这是为什么 ?我使用了dataset.map,不是应该在数据集中的每个元素上调用它(此处为10)吗?

在此先感谢您的帮助或建议。

1 个答案:

答案 0 :(得分:0)

我对它的理解是错误的,实际上我想出了使它起作用的方法。该映射函数“集成”到Tensorflow图,因此实际上仅调用了一次。需要在其中使用TF操作。

如果read_image被替换为:

def read_image(filename, label):
    return tf.random_normal([4]), tf.random_normal([1])

它按预期方式工作(每次调用session.run时都会生成一个随机值)。