我正在对输入数据进行多标签分类,数据包含330个特征和大约800条记录。我使用的是RandomForestClassifier和如下的param_grid:
> param_grid = {"n_estimators": [20],
> "max_depth": [6],
> "max_features": [80, 150],
> "min_samples_leaf": [1, 3, 10],
> "bootstrap": [True, False],
> "criterion": ["gini", "entropy"],
> "oob_score": [True, False]}
清理数据后,我按如下方式设置分类器、拟合模型并调用decision_function:
classifier = OneVsRestClassifier(RandomForestClassifier(param_grid))
y_score = classifier.fit(X_train, y_train).descition_function(X_test)
X_train的形状为(800,334),y_train的形状为(800,4)。类别数量为4。代码在sklearn 0.18中运行。
但是,运行时出现以下错误消息:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-164-db76d3122db8> in <module>()
1 classifier = OneVsRestClassifier(RandomForestClassifier(param_grid))
----> 2 y_score = classifier.fit(X_train, y_train).descition_function(X_test)
3 #clf = RandomForestClassifier()
4 #gr_search = grid_search.GridSearchCV(clf, param_grid02, cv=10, scoring = 'accuracy')
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/multiclass.py in fit(self, X, y)
214 "not %s" % self.label_binarizer_.classes_[i],
215 self.label_binarizer_.classes_[i]])
--> 216 for i, column in enumerate(columns))
217
218 return self
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in __call__(self, iterable)
756 # was dispatched. In particular this covers the edge
757 # case of Parallel used with an exhausted iterator.
--> 758 while self.dispatch_one_batch(iterator):
759 self._iterating = True
760 else:
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in dispatch_one_batch(self, iterator)
606 return False
607 else:
--> 608 self._dispatch(tasks)
609 return True
610
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in _dispatch(self, batch)
569 dispatch_timestamp = time.time()
570 cb = BatchCompletionCallBack(dispatch_timestamp, len(batch), self)
--> 571 job = self._backend.apply_async(batch, callback=cb)
572 self._jobs.append(job)
573
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py in apply_async(self, func, callback)
107 def apply_async(self, func, callback=None):
108 """Schedule a func to be run"""
--> 109 result = ImmediateResult(func)
110 if callback:
111 callback(result)
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py in __init__(self, batch)
320 # Don't delay the application, to avoid keeping the input
321 # arguments in memory
--> 322 self.results = batch()
323
324 def get(self):
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in __call__(self)
129
130 def __call__(self):
--> 131 return [func(*args, **kwargs) for func, args, kwargs in self.items]
132
133 def __len__(self):
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py in <listcomp>(.0)
129
130 def __call__(self):
--> 131 return [func(*args, **kwargs) for func, args, kwargs in self.items]
132
133 def __len__(self):
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/multiclass.py in _fit_binary(estimator, X, y, classes)
78 else:
79 estimator = clone(estimator)
---> 80 estimator.fit(X, y)
81 return estimator
82
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/ensemble/forest.py in fit(self, X, y, sample_weight)
281
282 # Check parameters
--> 283 self._validate_estimator()
284
285 if not self.bootstrap and self.oob_score:
/Users/ayada/anaconda/lib/python3.5/site-packages/sklearn/ensemble/base.py in _validate_estimator(self, default)
94 """Check the estimator and the n_estimator attribute, set the
95 `base_estimator_` attribute."""
---> 96 if self.n_estimators <= 0:
97 raise ValueError("n_estimators must be greater than zero, "
98 "got {0}.".format(self.n_estimators))
TypeError: unorderable types: dict() <= int()
答案 0 :(得分:1)
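为什么要尝试用参数网格(param_grid)来初始化RandomForestClassifier?RandomForestClassifier的第一个位置参数是n_estimators,因此整个字典被当作n_estimators传入,在检查 n_estimators <= 0 时就抛出了上面的TypeError(dict与int无法比较)。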
如果您想进行网格搜索,请将param_grid传给GridSearchCV,而不是传给分类器本身。请参阅以下示例: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV