在使用tf.contrib.data.Iterator
的情况下,如何初始化tf.estimator.Estimator
?
其中一个问题是输入图(tf图处理输入的部分)应该在intput_fn()
中定义 - 因为tf.estimator创建了分隔图。
这个要求使得很难访问迭代器init ops
并传递它们to tf.estimator
(在以钩子的形式调用train/evaluate/predict
时可以传递操作。)
答案 0 :(得分:0)
使用rawdata(1:799)=[]
作为钩子可以解决相同的问题。
SessionManager
答案 1 :(得分:0)
一种选择是将input_fn
包装在另一个设置简单SessionRunHook init_hook
的函数中。所有操作都在input_fn
内定义,在与模型的其余部分相同的图形中调用,但是从中可以将iterator_init_op
设置为init_hook
上的属性。
def get_input_fn(mode="train"):
init_hook = IteratorInitHook()
def input_fn():
...
iterator = dataset.make_initializable_iterator()
init_hook.iterator_init_op = iterator.initializer
return input_fn, init_hook
class IteratorInitHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord):
session.run(self.iterator_init_op)
现在,当构造Experiment
时,您可以获得这些输入函数和init挂钩,这些挂钩在创建训练/评估会话时被调用。它应该与estimator.train
等效。
train_input_fn, train_init_hook = get_input_fn("train")
test_input_fn, test_init_hook = get_input_fn("test")
return tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=test_input_fn,
train_monitors=[train_init_hook],
eval_hooks=[test_init_hook],
)