在Tensorflow中保存并恢复Experimenter / Estimator

时间:2017-07-26 08:29:51

标签: python tensorflow

我正在编写一个程序,训练4个神经网络并收集其结果以找到更好的结果。我将使用一个集合方法来做到这一点,但问题不在于它。

问题在于在培训过程之后恢复每个模型。 我发现了another question,但它根本没有帮助。

(伪)码

我的英语不是很好,所以我会尝试使用伪python代码解释我的工作流程:

for i in range(4):
    # Create the estimator (a DNNClassifier).
    estimator = build_estimator(...)
    # Train the model.
    estimator.fit(input_fn=...)

# Do other stuffs...

for i in range(4):
    # Restore the estimator using the same arguments.
    estimator = build_estimator(...)
    # Predict the input data.
    predictions[i] = estimator.predict(input_fn=...)

    # Do others stuffs using the predictions collection.

错误

从广义上讲,这是我的代码,即使它看起来很好也很简单,但它并不起作用。在恢复部分期间显示此错误,这意味着我的DNN未正确保存。

2017-07-24 11:40:24.517773: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/hiddenlayer_1/weights not found in checkpoint
2017-07-24 11:40:24.517884: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/hiddenlayer_0/biases not found in checkpoint
2017-07-24 11:40:24.518739: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/hiddenlayer_0/weights not found in checkpoint
2017-07-24 11:40:24.519621: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/logits/biases not found in checkpoint
2017-07-24 11:40:24.519684: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/hiddenlayer_1/biases not found in checkpoint
2017-07-24 11:40:24.519861: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/hiddenlayer_2/weights not found in checkpoint
2017-07-24 11:40:24.519947: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/hiddenlayer_2/biases not found in checkpoint
2017-07-24 11:40:24.522592: W c:\tf_jenkins\home\workspace\release-win\m\windows\py\35\tensorflow\core\framework\op_kernel.cc:1158] Not found: Key dnn/logits/weights not found in checkpoint

注意

  1. 已使用相同的参数创建和恢复DNNC分类器。
  2. 据我所知,我不需要保存任何检查点,因为DNNClassifier在培训过程中会这样做。
  3. 我在某处读过,在使用predict之前需要evaluation。我试过了,但没有改变。
  4. 如果您愿意,我可以分享我的代码的其他片段,但我认为这不会对您有所帮助。

1 个答案:

答案 0 :(得分:-1)

你应该创建这样的模型:

estimator = tf.contrib.learn.DNNClassifier(feature_columns = feature_cols,hidden_​​units = [10,10,10],n_classes = 2,model_dir = your_path)

estimator.fit(input_fn = lambda:input_fn1(训练数据),步数= 1000)

estimator_predict = estimator.predict(input_fn = lambda:input_fn1(test_data))