使用数据集API训练估算器少于一个纪元?

时间:2019-01-29 04:11:08

标签: tensorflow

我正在尝试在大型数据集上训练模型。我想在完成一个时期的培训之前多次运行评估步骤。看看使用Estimators的Dataset API的实现,好像我每次在评估步骤之后重新开始训练时,Estimator都会从头开始创建一个新的数据集,而训练永远无法完成全部数据。

我写了一个非常类似于tensorflow网站上提供的输入功能。

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), 
    labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.repeat(1).batch(batch_size)

    # Return the read end of the pipeline.
    return dataset

然后我使用tf.estimator.Estimator.train调用我的输入函数。我使用以下方法调用上述输入函数。

classifier.train(input_fn=lambda: train_input_fn, 
steps=n_steps)  

其中n_steps的数量少于完成一个纪元的总步数。

然后我这样调用评估函数。

classifier.evaluate(input_fn=lambda: eval_input_fn())

我希望两个步骤都循环运行。 每次循环进行训练时,都会初始化train_input_fn中的数据集。这仅在训练数据的前n个步骤中应用训练。

1 个答案:

答案 0 :(得分:0)

如果您想在训练期间进行多次评估,可以检查InMemoryEvaluatorHook

您可能可以参考this discussion关于train_and_evaluate和InMemoryEvaluatorHook,以获取有关如何更好地使用它们的更多详细信息。