如何检测参数网格中允许的值? (sklearn)

时间:2017-06-30 08:29:55

标签: python scikit-learn parameter-passing cross-validation

我已经开始研究一个项目,其中我需要检测给定scikit-learn估算器的可训练参数,如果可能的话,找到分类变量的允许值(以及连续值的合理间隔) )。

我可以使用estimator.get_params()获取带参数的字典,然后使用estimator.set_params(**{'var1':val1, 'var2':val2})设置值,依此类推。

例如,对于KNN分类器,我们有以下params的词: {'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}

现在,我可以使用值的类型来推断哪些是分类(str类型),连续(float),离散(int)等等。一个可能相关的问题是默认设置为NoneType的参数,但我可能不会触及这些,原因很充分。

现在的挑战是推断和定义参数网格以用于例如RandomizedSearchCV。对于离散和连续变量,问题可以使用例如try - except块与scipy.stats模块的组合,可能会限制区间位于"附近区域。围绕默认值(但同时注意不要将n_jobs设置为某个疯狂的值 - 可能需要硬编码,或稍后明确设置)。如果你有类似的经历,并有一些提示/技巧,我很乐意听到他们。

但现在真正的问题是:如何推断例如algorithm允许值实际为{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’} ??

我刚开始研究这个问题,如果我们尝试将其设置为某个不允许的值,我们可能会解析出的错误消息?我在这里寻找好主意,因为我想避免手动这样做(如果必须,我会这样做,但看起来相当不优雅......)

谢谢!

2 个答案:

答案 0 :(得分:0)

我找到了一个我正在研究的特定示例的解决方案,但是,它并没有很好地概括为其他文档字符串,因为没有设置约定以及如何为sklearn中的每个估算器编写它们。

因此,我发布了我的“解决方案”,以便其他人可以接管并可能改进它。请参阅以下代码段:

import re
from pprint import pprint 
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
doc = knn.__doc__ # Get the doc string
#from sklearn.svm import SVC
#svc = SVC()
#doc = svc.__doc__
pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern
re.compile(pattern)
matches = re.findall(pattern, doc)

clf_params = {}
previous_param = ''
for param, _, value in matches:
    if ":" in param and param[-4]!="_": # 'Hack-y'
        if param not in clf_params.keys():
            clf_params[param] = list()
            previous_param = param
        else:
            if len(value)>0:
                clf_params[previous_param].append(value)
pprint(clf_params)

此代码段打印

{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'],
 'leaf_size : ': [],
 'metric : ': [],
 'metric_params : ': [],
 'n_jobs : ': [],
 'n_neighbors : ': [],
 'p : ': [],
 'weights : ': ['uniform', 'distance']}

哪个是正确的。

但是,如果我们对SVC().__doc__重复相同的过程,我们会发现它失败了。

我希望某人认为这有点用处。

答案 1 :(得分:0)

我尝试从文档字符串(以LinearSVC作为示例算法)中获取所有这些信息,liner = str(LinearSVC().__doc__).split('Parameters\n ----------\n')[1].split('\n\n Attributes\n')[0].replace('\n ', '\n').splitlines() 对此提供了极大的帮助:

for i in liner:
   ...:     if " : " in i: #<<< the key is to use " : " as our anchor
   ...:         print(i)

这不会创建字典,但足够简单,仅从文档字符串中提取解释的“参数”部分,其中解释了所有参数,并列出了所有可能的/期望的/可接受的值输入,这很好用一个选项卡缩进,现在我们可以使用带条件的简单循环,以“:”为锚点来标识可能/期望/接受的值输入行

    penalty : str, 'l1' or 'l2' (default='l2')
    loss : str, 'hinge' or 'squared_hinge' (default='squared_hinge')
    dual : bool, (default=True)
    tol : float, optional (default=1e-4)
    C : float, optional (default=1.0)
    multi_class : str, 'ovr' or 'crammer_singer' (default='ovr')
    fit_intercept : bool, optional (default=True)
    intercept_scaling : float, optional (default=1)
    class_weight : {dict, 'balanced'}, optional
    verbose : int, (default=0)
    random_state : int, RandomState instance or None, optional (default=None)
    max_iter : int, (default=1000)

最终结果打印到:

print(str(LinearSVC().__doc__).split('Parameters\n    ----------\n')[1].split('\n\n    Attributes\n')[0].replace('\n        ', '\n'))

很高兴我可以分享,如果其他任何人需要完整的docstring参数打印输出,只需使用:

docstring_short = str([i for i in liner.splitlines() if " : " in i]).replace('["    ', '').replace('    ', ',\n').replace('", "', '').replace('", \'', '').replace("', '", '').replace("', \"", '').replace(']', '')

编辑: 如果不希望将其打印出来-使其成为字符串对象的最佳方法是使用列表推导,但是它需要进行一些难看的替换,因为文档字符串中有 extensive 表示法:

{{1}}