如何使用Tensorflow估算器

时间:2017-07-06 19:25:17

标签: python tensorflow

关于创建卷积神经网络我正在关注this Tensorflow tutorial

我正在阅读训练和测试数据的步骤:

def main(unused_argv):
  mnist = learn.datasets.load_dataset("mnist")
  train_data = mnist.train.images # Returns np.array
  train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
  eval_data = mnist.test.images # Returns np.array
  eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

到目前为止,一切都很好。

然后突然创建了一个估算器:

mnist_classifier = learn.Estimator(
      model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")

我的问题是:

  1. 什么是Estimator?

  2. 之前的代码不会在"/tmp/mnist_convnet_model"下保存任何内容。为什么在该目录下保存了一个模型? 它是如何实现的?

  3. 编辑:

    当我运行代码时,我得到:

     Couldn't find trained model at ../tmp/mnist_convnet_model. 
    

    这是因为在该目录结构下找不到该模型。

    我如何将模型放在那里?另外,为什么我必须把它放在那里,而不是将它存储在内存中以执行脚本。

1 个答案:

答案 0 :(得分:1)

第一个问题在教程中就已经回答了。 Estimator是一个TensorFlow类,用于执行高级模型培训,评估和推理等。

第二个问题的答案是,不,没有任何东西保存到该目录。估算器对象将使用此目录来保存训练检查点,日志等。第一次运行此代码时,它不会加载任何内容。但是一旦你训练了模型,它就会从那里加载保存的状态。