Error when passing parameters to DecisionTreeClassifier

Time: 2017-08-26 19:15:33

Tags: python scikit-learn decision-tree

I am trying to create a DecisionTreeClassifier with parameters taken from a string.

 print d    # d= 'max_depth=100'
 clf = DecisionTreeClassifier(d)
 clf.fit(X[:3000,], labels[:3000])

In this case I get the following error. If I use clf = DecisionTreeClassifier(max_depth=100) instead, it works fine.

Traceback (most recent call last):
  File "train.py", line 120, in <module>
    grid_search_generalized(X, labels, {"max_depth":[i for i in range(100, 200)]})
  File "train.py", line 51, in grid_search_generalized
    clf.fit(X[:3000,], labels[:3000])
  File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 790, in fit
    X_idx_sorted=X_idx_sorted)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 326, in fit
    criterion = CRITERIA_CLF[self.criterion](self.n_outputs_,
KeyError: 'max_depth=100'

2 Answers:

Answer 0 (score: 1)

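You are passing the parameter as a string object rather than as a keyword argument. If you really need to call the constructor with this string, you can use this code: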

 key, value = d.split("=")  # "max_depth", "100"
 arg = {key: int(value)}  # max_depth expects an int, not the string "100"
 clf = DecisionTreeClassifier(**arg)

You can read more about argument unpacking in this question: Passing a dictionary to a function in python as keyword parameters
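If the string may hold more than one parameter, or values that are not integers, the same unpacking idea can be generalized. The sketch below is only an illustration: the helper name `parse_kwargs` and the comma-separated `name=value` format are assumptions, not part of scikit-learn or of the question. It uses `ast.literal_eval` from the standard library to convert numbers, `None`, and booleans, and falls back to the raw string for values such as `entropy`.

```python
import ast


def parse_kwargs(spec):
    """Parse 'name=value, name=value' into a dict of keyword arguments.

    Hypothetical helper: the input format is an assumption for illustration.
    """
    kwargs = {}
    for item in spec.split(","):
        name, _, raw = item.partition("=")
        raw = raw.strip()
        try:
            # Converts literals such as 100, 0.5, None, True.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. entropy): keep it as a plain string.
            value = raw
        kwargs[name.strip()] = value
    return kwargs


params = parse_kwargs("max_depth=100, criterion=entropy")
print(params)  # {'max_depth': 100, 'criterion': 'entropy'}
```

With this in place, the classifier from the question could be built as `DecisionTreeClassifier(**parse_kwargs(d))`, assuming every name in the string is a parameter the constructor accepts.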

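Answer 1 (score: 1)

DecisionTreeClassifier does not treat the whole string 'max_depth=100' as a keyword argument; max_depth itself has to be passed as a keyword argument. Try the following code: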

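from sklearn.tree import DecisionTreeClassifier

d = 'max_depth=100'
arg = dict([d.split("=")])          # {'max_depth': '100'}
i = int(next(iter(arg.values())))   # parameter value as an int: 100
k = next(iter(arg.keys()))          # parameter name: 'max_depth'
clf = DecisionTreeClassifier(**{k: i})
clf.fit(X[:3000,], labels[:3000])   # X and labels come from the question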

Output:

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=100,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')
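
As for why the original call produced KeyError: 'max_depth=100': the traceback fails on CRITERIA_CLF[self.criterion], which suggests that the string passed positionally was bound to the first constructor parameter, criterion, and was only looked up when fit ran. The sketch below reproduces that behaviour with a toy class. `ToyTree` and `CRITERIA` are hypothetical stand-ins written for illustration; they are not scikit-learn code.

```python
# Hypothetical stand-in for the lookup table seen in the traceback.
CRITERIA = {"gini": "gini impurity", "entropy": "information gain"}


class ToyTree:
    """Toy class whose first parameter is `criterion`, for illustration only."""

    def __init__(self, criterion="gini", max_depth=None):
        self.criterion = criterion
        self.max_depth = max_depth

    def fit(self):
        # Mirrors the failing line: the criterion is used as a dictionary key.
        return CRITERIA[self.criterion]


# Positional string: it lands in `criterion`, and max_depth stays None.
wrong = ToyTree("max_depth=100")
print(wrong.criterion, wrong.max_depth)  # max_depth=100 None

try:
    wrong.fit()
except KeyError as exc:
    print("KeyError:", exc)  # KeyError: 'max_depth=100'

# Keyword argument: max_depth is set and the default criterion is kept.
right = ToyTree(max_depth=100)
print(right.fit())  # gini impurity
```

The constructor accepts the bad value silently, which is why the error only appears inside fit.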