使用tf.data.Dataset评估每N个步骤

时间:2018-04-15 22:47:24

标签: python tensorflow machine-learning

TensorFlow是否有某种方法可以使用tf.data.Dataset API在每N个培训步骤中自动评估评估集?目前,我的输入函数如下所示:

def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(train_x), train_y))

    return (
        dataset
        .repeat()
        .shuffle(len(train_x) * 1.33))
        .batch(128)
        .make_one_shot_iterator().get_next()
    )

def eval_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(eval_x), eval_y))

    return (
        dataset
        .batch(len(eval_x)) # to use the entire eval set
        .make_one_shot_iterator().get_next()
    )

并在tf.estimator.DNNRegressor这样的实例上调用它们:

est = tf.estimator.DNNRegressor(...)

est.train(input_fn=train_input_fn, steps=5000)
est.evaluate(input_fn=eval_input_fn, steps=1)

1 个答案:

答案 0 :(得分:0)

根据建议in this StackOverflow answer使用已弃用的tf.contrib.learn.monitors.ValidationMonitor解决。 ValidationMonitor仍然可以使用Estimator实用程序功能在monitors.replace_monitors_with_hooks上使用from tensorflow.contrib.learn.python.learn import monitors as monitor_lib est = tf.Estimator.DNNRegressor(...) validation_monitor = tf.contrib.learn.monitors.ValidationMonitor( input_fn=eval_input_fn, every_n_steps=100, ) list_of_monitors_and_hooks = [validation_monitor] hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, est) est.train( input_fn=input_fn_train, steps=1000, hooks=hooks )

这是我的实施:

db2advis