如何在'input_fn'中使用tensorflow的迭代器'make_initializable_iterator'?

时间:2018-02-05 00:54:29

标签: python tensorflow tensorflow-datasets tensorflow-estimator

我想用tf.estimator.Estimator训练我的模式并通过Dataset API加载我的数据。因为我的数据,例如'mnist',是一个数组(张量),所以我尝试用'tf加载它.data.Dataset.from_tensor_slices'。但我不知道如何在'input_fn'中初始化'make_initializable_iterator'。

如果我可以使用'make_one_shot_iterator'成功训练,但在训练前它会慢慢加载。并且“Higher-Level APIs in TensorFlow”是'input_fn'中'make_initializable_iterator'的一个很好的例子,但它需要从'input_fn'向其他函数返回'iterator_initializer_hook'。我想知道还有其他更好或更优雅的方式吗?

    def input_fn():

    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset iterator
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    iterator = dataset.make_one_shot_iterator()
    next_example = iterator.get_next()
    # Set runhook to initialize iterator

    return next_example

2 个答案:

答案 0 :(得分:6)

在TensorFlow版本1.5及更高版本中,当您从tf.estimator.Estimator返回tf.data.Dataset时,input_fn将自动创建并初始化可初始化的迭代器。这使您可以编写以下代码,而无需担心初始化或挂钩:

def input_fn():
    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset.
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    return dataset

答案 1 :(得分:0)

在您的代码中,添加以下内容:

      self.hooks.append(utils_hooks.DatasetHook(iter))

在run_loop.py中,在调用fn之前,添加此

 for hook in dataset_hooks:
        sess.run(hook.iterator().initializer)

那应该没事。