tensorflow SKCompat is not compatible with cross_val_score

Posted: 2017-05-11 00:43:32

Tags: python-2.7 tensorflow scikit-learn

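I'm trying to use a tensorflow classifier together with some tools from scikit-learn, namely model_selection.cross_val_score. When I run the following code (adapted from this example from the tensorflow docs), I get a TypeError (see the full traceback below).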

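As far as I can tell, the problem is that cross_val_score tries to clone the estimator by doing the equivalent of estimator.__class__(**estimator.get_params(deep=False)). For some reason SKCompat.get_params returns {}, but the class's __init__ method requires an argument (the wrapped estimator, as the example code shows), so the call blows up.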
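The failure mode described above can be reproduced without tensorflow. The sketch below is a simplified, stdlib-only imitation of the core of sklearn.base.clone (the name clone_like is hypothetical; the real function also recursively clones nested estimators and validates the result). EmptyParamsWrapper is a hypothetical stand-in for SKCompat whose get_params returns {}; ParamAwareWrapper is a hypothetical wrapper that follows the scikit-learn convention of storing every __init__ argument unchanged and returning it from get_params.

```python
from __future__ import print_function

import copy


def clone_like(estimator):
    """Simplified imitation of sklearn.base.clone (not the real function).

    Builds a fresh, unfitted object by calling the estimator's class with the
    constructor parameters reported by get_params(deep=False).
    """
    klass = estimator.__class__
    params = estimator.get_params(deep=False)
    # The real clone recursively clones nested parameters; deepcopy stands in here.
    params = {name: copy.deepcopy(value) for name, value in params.items()}
    return klass(**params)


class EmptyParamsWrapper(object):
    """Hypothetical SKCompat stand-in: __init__ needs an argument,
    but get_params reports no parameters at all."""

    def __init__(self, estimator):
        self.estimator = estimator

    def get_params(self, deep=True):
        return {}


class ParamAwareWrapper(object):
    """Hypothetical wrapper following the scikit-learn convention:
    each __init__ argument is stored as-is and returned by get_params."""

    def __init__(self, hidden_units=(10, 20, 10), n_classes=3):
        self.hidden_units = hidden_units
        self.n_classes = n_classes

    def get_params(self, deep=True):
        return {"hidden_units": self.hidden_units, "n_classes": self.n_classes}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


try:
    # get_params() is {}, so the class is called with no arguments and fails.
    clone_like(EmptyParamsWrapper(estimator="some model"))
except TypeError as exc:
    print("clone failed:", exc)

# Cloning succeeds because every constructor argument is reported by get_params.
fresh = clone_like(ParamAwareWrapper(hidden_units=(32, 16), n_classes=3))
print(fresh.get_params())
```

On Python 2.7 the first call raises the same "__init__() takes exactly 2 arguments (1 given)" message seen in the traceback. In a real workaround along these lines, fit() would be the place to build the tf.contrib.learn.DNNClassifier from the stored hyperparameters, so each clone starts from an unfitted model; that part is an assumption and is not shown here.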

Am I doing something wrong, or is this a bug in tensorflow?

Failing example

"""Example of DNNClassifier for Iris plant dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from sklearn import metrics
from sklearn import model_selection

import tensorflow as tf


def main(unused_argv):
  # Load dataset.
  iris = tf.contrib.learn.datasets.load_dataset('iris')

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
      iris.data)
  classifier = tf.contrib.learn.SKCompat(
    tf.contrib.learn.DNNClassifier(
      feature_columns=feature_columns,
      hidden_units=[10, 20, 10],
      n_classes=3
    )
  )

  # Cross-validate: fit and score the classifier on each fold.
  scores = model_selection.cross_val_score(classifier, iris.data, iris.target,
          scoring='accuracy')
  print('Accuracy: {0:f}'.format(scores.mean()))


if __name__ == '__main__':
  tf.app.run()

Traceback

Traceback (most recent call last):
  File "iris.py", line 49, in <module>
    tf.app.run()
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "iris.py", line 44, in main
    scoring='accuracy')
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 140, in cross_val_score
    for train, test in cv_iter)
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 758, in __call__
    while self.dispatch_one_batch(iterator):
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 603, in dispatch_one_batch
    tasks = BatchedCalls(itertools.islice(iterator, batch_size))
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 127, in __init__
    self.items = list(iterator_slice)
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/model_selection/_validation.py", line 140, in <genexpr>
    for train, test in cv_iter)
  File "/Users/Matt/.virtualenvs/numerai/lib/python2.7/site-packages/sklearn/base.py", line 70, in clone
    new_object = klass(**new_object_params)
TypeError: __init__() takes exactly 2 arguments (1 given)

Versions

python: 2.7.3

tensorflow: 1.1.0

scikit-learn: 0.18.1

0 answers