如何使用tf.data的可初始化迭代器和可重新初始化的插入器,并将数据馈送到估算器api?

时间:2018-08-01 04:50:05

标签: python tensorflow tensorflow-datasets tensorflow-estimator

所有正式的google教程都对所有estimator api实现使用一次快照迭代器,我找不到任何有关如何使用tf.data的可初始化迭代器和可重新初始化迭代器的文档,而不是一个快照迭代器。

有人可以告诉我如何使用tf.data的可初始化迭代器和可重新初始化插入器在train_data和test_data之间进行切换。我们需要运行一个会话来使用feed dict,并在可初始化的迭代器,其低级api及其令人困惑的如何使用它的方法中切换数据集estimator api体系结构的一部分

PS:我确实发现Google提到了 “注意:目前,单次迭代器是唯一可与Estimator一起使用的类型。”

但是社区内部有什么工作吗?还是出于某些原因我们应该坚持使用一次射击迭代器

2 个答案:

答案 0 :(得分:4)

要使用可初始化或可初始化的迭代器,必须创建一个继承自tf.train.SessionRunHook的类。然后,此类可以访问tf.estimator函数使用的会话。

这是一个可以满足您的需求的简单示例:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session) 


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
                                                                                    feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])

答案 1 :(得分:0)

或者您可以简单地使用tf.estimator.train_and_evaluate https://www.tensorflow.org/api_docs/python/tf/estimator/train_and_evaluate 它使您可以在训练过程中使用验证,而完全不必关心迭代器。