使用tf.estimator初始化tf.contrib.data.Iterator

时间:2017-08-10 17:52:49

标签: tensorflow tensorflow-datasets tensorflow-estimator

在使用tf.contrib.data.Iterator的情况下,如何初始化tf.estimator.Estimator

其中一个问题是输入图(tf图处理输入的部分)应该在intput_fn()中定义 - 因为tf.estimator创建了分隔图。

这个要求使得很难访问迭代器init ops并传递它们to tf.estimator(在以钩子的形式调用train/evaluate/predict时可以传递操作。)

2 个答案:

答案 0 :(得分:0)

使用rawdata(1:799)=[]作为钩子可以解决相同的问题。

SessionManager

答案 1 :(得分:0)

一种选择是将input_fn包装在另一个设置简单SessionRunHook init_hook的函数中。所有操作都在input_fn内定义,在与模型的其余部分相同的图形中调用,但是从中可以将iterator_init_op设置为init_hook上的属性。

def get_input_fn(mode="train"):
    init_hook = IteratorInitHook()

    def input_fn():
        ...
        iterator = dataset.make_initializable_iterator()
        init_hook.iterator_init_op = iterator.initializer

    return input_fn, init_hook

class IteratorInitHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord):
        session.run(self.iterator_init_op)

现在,当构造Experiment时,您可以获得这些输入函数和init挂钩,这些挂钩在创建训练/评估会话时被调用。它应该与estimator.train等效。

train_input_fn, train_init_hook = get_input_fn("train")
test_input_fn, test_init_hook = get_input_fn("test")

return tf.contrib.learn.Experiment(
    estimator=estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=test_input_fn,
    train_monitors=[train_init_hook],
    eval_hooks=[test_init_hook],
)