在sklearn估算器上调用set_params()时,super()出错

时间:2017-07-13 21:14:19

标签: python scikit-learn python-3.4

我试图根据配置文件加载和配置scikit-learn估算器。该文件具有估算器类路径和名称以及参数的字典。我的计划是使用pydoc.locate()加载带有默认参数的估计器,然后使用参数的dict在估算器上调用set_params()。但是我收到以下错误:

import pydoc
sgd = pydoc.locate('sklearn.linear_model.SGDClassifier')
print('{} {}'.format(type(sgd), sgd))
p_sgd = {'alpha':.1234}
sgd.set_params(p_sgd)
<class 'abc.ABCMeta'> <class 'sklearn.linear_model.stochastic_gradient.SGDClassifier'>
Traceback (most recent call last):
  File "<input>", line 5, in <module>
  File "/Users/doug/.pyenv/versions/learning-3.4.3/lib/python3.4/site-packages/sklearn/linear_model/stochastic_gradient.py", line 83, in set_params
    super(BaseSGD, self).set_params(*args, **kwargs)
TypeError: super(type, obj): obj must be an instance or subtype of type

我尝试使用同样的#34;加载并设置&#34;接近两次。第一次,我按名称加载文本矢量化器并设置其参数。文本向量化程序是我基于HashingVectorizer创建的子类。它不会产生此错误,但似乎也没有通过调用set_params()来更改(即参数值保持默认值)。第二次是具有我描述的行为的分类器。

我在使用提供给GridSearchCV的Pipeline中运行它们时,先使用pydoc.locate()加载估算器。这工作得很好。在这种情况下,我使用默认的估计器构造函数构造管道,然后GridSearchCV让Pipeline在遍历参数网格时在每个估算器上调用set_params()。通过Pipeline和GridSearchCV源看,他们使用set_params()被称为set_params(** param_dict)。如果我试试,我会得到一个不同的错误。

import pydoc
sgd = pydoc.locate('sklearn.linear_model.SGDClassifier')
p_sgd = {'alpha':.1234}
sgd.set_params(**p_sgd)
Traceback (most recent call last):
  File "<input>", line 4, in <module>
TypeError: set_params() missing 1 required positional argument: 'self'

最后一点,我已经读过原始错误(TypeError:super(type,obj)...)已被追踪到多次加载模块的问题。事实上我在这些尝试调用之前使用了pydoc.locate()(为了追踪他们的父母并找出谁是矢量化器与分类器)。我可能能够解决这个问题,但是之前仍然会尝试加载这些模块,因为我在循环中运行以根据配置文件训练多个模型。

我正在使用Python 3.4

1 个答案:

答案 0 :(得分:0)

正如user2357112指出的那样,我错误地只加载了类,而不是构造它。我更改了代码以在没有参数的情况下调用返回类的构造函数,然后使用我期望的**参数语法调用set_params(** p_sgd)。

import pydoc
sgd = pydoc.locate('sklearn.linear_model.SGDClassifier')()
p_sgd = {'alpha':.1234}
sgd.set_params(**p_sgd)
sgd
SGDClassifier(alpha=0.1234, average=False, class_weight=None, epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15, learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5, random_state=None, shuffle=True, verbose=0, warm_start=False)