与TensorFlowDNNClassifier相比,DNNClassifier是否不稳定?

时间:2016-07-16 16:26:22

标签: tensorflow deep-learning skflow

我使用TF v0.9建立基于skflow的DNN预测(0或1)模型。 我TensorFlowDNNClassifier的代码是这样的。我培训了大约26,000条记录并测试了6,500条记录。

classifier = learn.TensorFlowDNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)
test_pred = classifier.predict(test_features)
print(classification_report(test_labels, test_pred))

大约需要1分钟才能得到结果。

             precision    recall  f1-score   support
          0       0.77      0.92      0.84      4265
          1       0.75      0.47      0.58      2231
avg / total       0.76      0.76      0.75      6496

但我得到了

WARNING:tensorflow:TensorFlowDNNClassifier class is deprecated. 
Please consider using DNNClassifier as an alternative.

所以我简单地用DNNClassifier更新了我的代码。

classifier = learn.DNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)

它也很好用。但结果并不相同。

             precision    recall  f1-score   support
          0       0.77      0.96      0.86      4265
          1       0.86      0.45      0.59      2231
avg / total       0.80      0.79      0.76      6496

1的精确度得到提升。 当然这对我有好处,但为什么会有所改善? 大约需要2个小时。 这比前一个例子慢了约120倍

我有什么不对吗?还是错过了一些参数? 或DNNClassifier与TF v0.9不稳定?

1 个答案:

答案 0 :(得分:0)

我给出与here相同的答案。您可能会遇到这种情况,因为您使用了steps参数而不是 max_steps 。这只是TensorFlowDNNClassifier的一个步骤,实际上是max_steps。现在你可以决定你是否真的希望在你的情况下50000步或先前自动中止。