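I'm trying to use a TensorFlow classifier with some of the tools from scikit-learn, namely `model_selection.cross_val_score`. When I run the following code (adapted from this example from the TensorFlow docs), I get a `TypeError` (see the full traceback below).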
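As far as I can tell, the problem is that `cross_val_score` tries to clone the estimator by doing the equivalent of `estimator.__class__(**estimator.get_params(deep=True))`. For some reason, `SKCompat.get_params` returns `{}`. However, the class's `__init__` method requires an argument (the wrapped estimator, as the example code shows), so the call blows up.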
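The clone failure described above can be reproduced without TensorFlow or scikit-learn. This is a minimal sketch under stated assumptions:

* `HypotheticalWrapper` is a hypothetical stand-in for `SKCompat`.
* `naive_clone` only mimics the "rebuild a new instance from `get_params()`" step. It is not scikit-learn's actual `clone`.
* `FixedWrapper` shows the contrasting case, where `get_params()` reports the constructor argument.

```python
# Pure-Python reproduction of the clone failure (no TensorFlow/sklearn needed).
# HypotheticalWrapper is a hypothetical stand-in for SKCompat; naive_clone only
# mimics the "rebuild from get_params()" step, not the full sklearn.base.clone.


class HypotheticalWrapper(object):
    """Wraps another estimator but, like SKCompat here, reports no params."""

    def __init__(self, estimator):
        self._estimator = estimator

    def get_params(self, deep=True):
        # Returning {} hides the required constructor argument.
        return {}


class FixedWrapper(HypotheticalWrapper):
    """Same wrapper, but get_params() exposes the constructor argument."""

    def get_params(self, deep=True):
        return {"estimator": self._estimator}


def naive_clone(est):
    # Build a new instance of the same class from the params it reports.
    return est.__class__(**est.get_params())


def try_clone(est):
    """Return "ok" if naive_clone succeeds, else the TypeError message."""
    try:
        naive_clone(est)
        return "ok"
    except TypeError as exc:
        return "TypeError: {}".format(exc)


if __name__ == "__main__":
    print(try_clone(HypotheticalWrapper("inner")))  # fails: __init__ gets no args
    print(try_clone(FixedWrapper("inner")))         # succeeds
```

On Python 2.7, the first call's message has the same shape as the one in the traceback below: `__init__() takes exactly 2 arguments (1 given)`. The count includes `self`.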
Am I doing something wrong, or is this a TensorFlow bug?
```python
"""Example of DNNClassifier for Iris plant dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sklearn import metrics
from sklearn import model_selection

import tensorflow as tf


def main(unused_argv):
    # Load dataset.
    iris = tf.contrib.learn.datasets.load_dataset('iris')

    # Build 3 layer DNN with 10, 20, 10 units respectively.
    feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
        iris.data)
    classifier = tf.contrib.learn.SKCompat(
        tf.contrib.learn.DNNClassifier(
            feature_columns=feature_columns,
            hidden_units=[10, 20, 10],
            n_classes=3
        )
    )

    # Fit and predict.
    scores = model_selection.cross_val_score(classifier, iris.data, iris.target,
                                             scoring='accuracy')
    print('Accuracy: {0:f}'.format(scores.mean()))


if __name__ == '__main__':
    tf.app.run()
```
```
Traceback (most recent call last):
  File "iris.py", line 49, in <module>
    tf.app.run()
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "iris.py", line 44, in main
    scoring='accuracy')
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 140, in cross_val_score
    for train, test in cv_iter)
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 758, in __call__
    while self.dispatch_one_batch(iterator):
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 603, in dispatch_one_batch
    tasks = BatchedCalls(itertools.islice(iterator, batch_size))
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 127, in __init__
    self.items = list(iterator_slice)
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 140, in <genexpr>
    for train, test in cv_iter)
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/base.py", line 70, in clone
    new_object = klass(**new_object_params)
TypeError: __init__() takes exactly 2 arguments (1 given)
```
python:2.7.3
tensorflow:1.1.0
scikit-learn:0.18.1
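One way to sidestep cloning entirely is to write the cross-validation loop by hand and build a fresh estimator for every fold from a factory function. This way, `get_params()` is never consulted.

This is only a sketch under stated assumptions, not scikit-learn's or TensorFlow's API:

* `MajorityClassifier` is a hypothetical toy estimator.
* `kfold_indices`, `manual_cross_val_accuracy`, and the factory argument are illustrative names.
* In the question's setting, the factory would instead construct a new `SKCompat(DNNClassifier(...))` on each call.

```python
# Hand-rolled K-fold cross-validation that builds a fresh estimator per fold
# via a factory, instead of cloning one instance. MajorityClassifier is a
# hypothetical toy estimator standing in for SKCompat(DNNClassifier(...)).
from collections import Counter


class MajorityClassifier(object):
    """Toy estimator: always predicts the most common training label."""

    def fit(self, X, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


def kfold_indices(n_samples, n_splits):
    """Yield (train, test) index lists for contiguous, unshuffled folds."""
    fold_sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        fold_sizes[i] += 1  # spread the remainder over the first folds
    start = 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = [i for i in range(n_samples) if i < start or i >= start + size]
        yield train, test
        start += size


def manual_cross_val_accuracy(make_estimator, X, y, n_splits=3):
    """Return per-fold accuracy, training a new estimator in every fold."""
    scores = []
    for train, test in kfold_indices(len(X), n_splits):
        est = make_estimator()  # fresh, untrained estimator for this fold
        est.fit([X[i] for i in train], [y[i] for i in train])
        preds = est.predict([X[i] for i in test])
        correct = sum(p == y[i] for p, i in zip(preds, test))
        scores.append(correct / float(len(test)))
    return scores


if __name__ == "__main__":
    X = [[0]] * 6
    y = [0, 0, 0, 0, 0, 1]
    print(manual_cross_val_accuracy(MajorityClassifier, X, y, n_splits=3))
```

The folds here are contiguous and unshuffled. If the data is sorted by label, shuffle it first, or some test folds will contain classes never seen in training.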