如何使FunctionTransformer和GridSearchCV一起进入管道?

时间:2019-08-21 03:17:36

标签: python machine-learning scikit-learn

基本上,我想将列索引视为超参数。然后,将此超参数与管道中的其他模型超参数一起调整。在下面的示例中,col_idx是我的超参数。我自定义了一个名为log_columns的函数,该函数可以对某些列执行日志转换,并且可以将该函数传递到FunctionTransformer中。然后将FunctionTransformer和模型放入管道中。

from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.datasets import load_digits
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import FunctionTransformer


def log_columns(X, col_idx = None):
    log_func = np.vectorize(np.log)
    if col_idx is None:
        return X
    for idx in col_idx:
        X[:,idx] = log_func(X[:,idx])
    return X

pipe = make_pipeline(FunctionTransformer(log_columns, ), PCA(), SVC())
param_grid = dict(functiontransformer__col_idx = [None, [1]],
              pca__n_components=[2, 5, 10],
              svc__C=[0.1, 10, 100],
              )

grid_search = GridSearchCV(pipe, param_grid=param_grid)
digits = load_digits()

res = grid_search.fit(digits.data, digits.target)

然后,我收到以下错误消息:

ValueError: Invalid parameter col_idx for estimator 
FunctionTransformer(accept_sparse=False, check_inverse=True,
      func=<function log_columns at 0x1764998c8>, inv_kw_args=None,
      inverse_func=None, kw_args=None, pass_y='deprecated',
      validate=None). Check the list of available parameters with 
`estimator.get_params().keys()`.

我不确定FunctionTransformer是否允许我执行预期的工作。如果不是,我很想知道其他优雅的方法。谢谢!

2 个答案:

答案 0 :(得分:2)

col_idx不是FunctionTransformer类的有效参数,而kw_args是有效的参数。 kw_argsfunc的其他关键字参数的字典。就你而言 唯一的关键字参数是col_idx

尝试一下:

param_grid = dict(
    functiontransformer__kw_args=[
        {'col_idx': None},
        {'col_idx': [1]}
    ],
    pca__n_components=[2, 5, 10],
    svc__C=[0.1, 10, 100],
)

答案 1 :(得分:1)

首先,您应该检查可以调整的参数:pipe.get_params().keys()

之后,请查看documentation,了解如何组织param_grid