我使用TF v0.9建立基于skflow的DNN预测(0或1)模型。
我TensorFlowDNNClassifier
的代码是这样的。我培训了大约26,000条记录并测试了6,500条记录。
classifier = learn.TensorFlowDNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)
test_pred = classifier.predict(test_features)
print(classification_report(test_labels, test_pred))
大约需要1分钟才能得到结果。
precision recall f1-score support
0 0.77 0.92 0.84 4265
1 0.75 0.47 0.58 2231
avg / total 0.76 0.76 0.75 6496
但我得到了
WARNING:tensorflow:TensorFlowDNNClassifier class is deprecated.
Please consider using DNNClassifier as an alternative.
所以我简单地用DNNClassifier
更新了我的代码。
classifier = learn.DNNClassifier(hidden_units=[64, 128, 64], n_classes=2)
classifier.fit(features, labels, steps=50000)
它也很好用。但结果并不相同。
precision recall f1-score support
0 0.77 0.96 0.86 4265
1 0.86 0.45 0.59 2231
avg / total 0.80 0.79 0.76 6496
1
的精确度得到提升。
当然这对我有好处,但为什么会有所改善?
大约需要2个小时。
这比前一个例子慢了约120倍。
我有什么不对吗?还是错过了一些参数?
或DNNClassifier
与TF v0.9不稳定?
答案 0 :(得分:0)
我给出与here相同的答案。您可能会遇到这种情况,因为您使用了steps参数而不是 max_steps 。这只是TensorFlowDNNClassifier的一个步骤,实际上是max_steps。现在你可以决定你是否真的希望在你的情况下50000步或先前自动中止。