How can I make GridSearchCV work with a custom transformer in my pipeline?

Time: 2015-06-22 20:14:47

Tags: python pandas scikit-learn

GridSearchCV runs fine if I exclude my custom transformer, but with the transformer in the pipeline it raises an error. Here is a dummy dataset and the code that fails:

import pandas
import numpy
from sklearn_pandas import DataFrameMapper
from sklearn_pandas import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.grid_search import GridSearchCV
from sklearn.base import TransformerMixin
from sklearn.preprocessing import LabelBinarizer
from sklearn.ensemble import RandomForestClassifier
import sklearn_pandas
from sklearn.preprocessing import MinMaxScaler

df = pandas.DataFrame({"Letter":["a","b","c","d","a","b","c","d","a","b","c","d","a","b","c","d"],
                       "Number":[1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4], 
                       "Label":["G","G","B","B","G","G","B","B","G","G","B","B","G","G","B","B"]})

class MyTransformer(TransformerMixin):

    def transform(self, x, **transform_args):
        x["Number"] = x["Number"].apply(lambda row: row*2)
        return x

    def fit(self, x, y=None, **fit_args):
        return self

x_train = df
y_train = x_train.pop("Label")    

mapper = DataFrameMapper([
    ("Number", MinMaxScaler()),
    ("Letter", LabelBinarizer()),
    ])

pipe = Pipeline([
    ("custom", MyTransformer()),
    ("mapper", mapper),
    ("classifier", RandomForestClassifier()),
    ])


param_grid = {"classifier__min_samples_split":[10,20], "classifier__n_estimators":[2,3,4]}

model_grid = sklearn_pandas.GridSearchCV(pipe, param_grid, verbose=2, scoring="accuracy")

model_grid.fit(x_train, y_train)

How can I make GridSearchCV work when there is a custom transformer in my pipeline?

2 answers:

Answer 0: (score: 0)

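Short version: pandas and scikit-learn's cross-validation methods don't like to talk to each other this way (in my version, 0.15); this may be fixed simply by updating scikit-learn to 0.16/stable or 0.17/dev.

The GridSearchCV class validates the data and converts it to an array (so that it can perform the CV splits correctly). As a result, you cannot use pandas DataFrame features inside the built-in cross-validation loops.

If you want to do this kind of thing, you have to write your own cross-validation routines that do no validation.

EDIT: This is my experience with scikit-learn's cross-validation routines, and it is the reason sklearn-pandas provides cross_val_score. As far as I can tell, however, GridSearchCV is not specialized for sklearn-pandas; importing it from there just pulls in the default sklearn version. You may therefore have to implement your own grid search using ParameterGrid and sklearn-pandas's cross_val_score.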
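
A hand-rolled grid search along those lines could look like the sketch below. It is only an illustration under a few assumptions: it uses a recent scikit-learn (`sklearn.model_selection`), it replaces sklearn-pandas's cross_val_score with an explicit KFold loop that slices the DataFrame with `.iloc`, and the names `DoubleNumber`, `ToArray` and `manual_grid_search` are made up for this example.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, ParameterGrid
from sklearn.pipeline import Pipeline


class DoubleNumber(TransformerMixin, BaseEstimator):
    """Hypothetical stand-in for MyTransformer; works on a copy of X."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = X.copy()
        X["Number"] = X["Number"] * 2
        return X


class ToArray(TransformerMixin, BaseEstimator):
    """Hypothetical final preprocessing step: DataFrame -> numeric array."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[["Number"]].to_numpy(dtype=float)


def manual_grid_search(pipe, param_grid, X, y, n_splits=2):
    """Score every parameter combination, keeping X a DataFrame per fold."""
    folds = list(KFold(n_splits=n_splits, shuffle=True, random_state=0).split(X))
    results = []
    for params in ParameterGrid(param_grid):
        scores = []
        for train_idx, test_idx in folds:
            model = clone(pipe).set_params(**params)
            model.fit(X.iloc[train_idx], y.iloc[train_idx])
            scores.append(model.score(X.iloc[test_idx], y.iloc[test_idx]))
        results.append((float(np.mean(scores)), params))
    best_score, best_params = max(results, key=lambda item: item[0])
    return best_score, best_params, results


X = pd.DataFrame({"Letter": list("abcd") * 4, "Number": [1, 2, 3, 4] * 4})
y = pd.Series(["G", "G", "B", "B"] * 4)

pipe = Pipeline([
    ("custom", DoubleNumber()),
    ("to_array", ToArray()),
    ("classifier", RandomForestClassifier(random_state=0)),
])
param_grid = {
    "classifier__min_samples_split": [2, 4],
    "classifier__n_estimators": [2, 3],
}
best_score, best_params, results = manual_grid_search(pipe, param_grid, X, y)
```

Because every fold is cut with `.iloc`, each transformer in the pipeline still receives a DataFrame, which is the property the built-in loop was losing.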

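Answer 1: (score: 0)

I know this answer comes rather late, but I ran into the same behavior with sklearn and the classes derived from BaseSearchCV. The problem actually seems to stem from the _PartitionIterator class in sklearn's cross_validation module: it assumes that everything emitted by each TransformerMixin class in the pipeline is array-like, so it generates index arrays that are used to index the incoming X args in an array-like manner. Here is the __iter__ method: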

def __iter__(self):
    ind = np.arange(self.n)
    for test_index in self._iter_test_masks():
        train_index = np.logical_not(test_index)
        train_index = ind[train_index]
        test_index = ind[test_index]
        yield train_index, test_index 

The BaseSearchCV grid search base class calls cross_validation's _fit_and_score, which uses a method called safe_split. Here is the relevant line:

X_subset = [X[idx] for idx in indices]

This will definitely produce unexpected results if X is a pandas DataFrame that you emitted from your transform function.
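
To make the mismatch concrete, here is a small sketch. `iter_train_test_indices` is a hypothetical helper that mimics the mask-to-index logic of the __iter__ method above with contiguous folds; it is not sklearn's code. The KeyError is what current pandas versions raise, and older versions may behave differently.

```python
import numpy as np
import pandas as pd


def iter_train_test_indices(n, n_folds):
    """Yield (train_index, test_index) integer arrays for contiguous folds."""
    ind = np.arange(n)
    fold_sizes = np.full(n_folds, n // n_folds)
    fold_sizes[: n % n_folds] += 1
    start = 0
    for size in fold_sizes:
        test_mask = np.zeros(n, dtype=bool)
        test_mask[start:start + size] = True
        yield ind[np.logical_not(test_mask)], ind[test_mask]
        start += size


df = pd.DataFrame({"Letter": list("abcdab"), "Number": [1, 2, 3, 4, 1, 2]})
train_index, test_index = next(iter_train_test_indices(len(df), n_folds=3))

# On an array, integer positions select rows, as the CV code expects.
array_rows = df.to_numpy()[test_index]

# On a DataFrame, the same expression is a column-label lookup instead.
try:
    df[test_index]
    column_lookup_failed = False
except KeyError:
    column_lookup_failed = True

# Explicit positional indexing returns the intended rows.
frame_rows = df.iloc[test_index]
```

The index arrays are only meaningful as row positions, so anything that receives a DataFrame has to use `.iloc` (or convert to an array first) to get the rows the splitter intended.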

I have found two ways to work around this:

  1. Make sure your transformer returns an array:

    return x.values  # originally x.as_matrix(), which was removed in pandas 1.0
    
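  2. This one is a hack. If your transformer pipeline requires the input to the next transformer to be a DataFrame, as it did in my case, you can write a utility script that is essentially identical to sklearn's grid_search module but includes some validation helpers, called in the _fit method of the BaseSearchCV class: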

    import numpy as np
    import pandas as pd

    def _validate_X(X):
        """Returns X if X isn't a pandas frame, otherwise
        the underlying array in the frame. """
        # .values replaces as_matrix(), which was removed in pandas 1.0
        return X if not isinstance(X, pd.DataFrame) else X.values

    def _validate_y(y):
        """Returns y if y isn't a series or frame, otherwise the array"""
        if y is None:
            return y

        # if it's a series
        elif isinstance(y, pd.Series):
            return np.array(y.tolist())

        # if it's a dataframe:
        elif isinstance(y, pd.DataFrame):
            # y must have exactly one column
            if y.shape[1] > 1:
                raise ValueError('matrix provided as y')
            # return an array here too, consistent with the Series branch
            return np.array(y.iloc[:, 0].tolist())

        # bail and let the sklearn function handle validation
        return y
    
For example, here's my "custom grid_search module".
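
The first option (returning an array from the transformer) might look like the following sketch. It assumes a recent scikit-learn where GridSearchCV lives in `sklearn.model_selection`. `FrameToArray` is a hypothetical transformer written for this example: it does all DataFrame work internally, including the encoding the question delegated to DataFrameMapper, and hands a plain NumPy array to the classifier.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


class FrameToArray(TransformerMixin, BaseEstimator):
    """Hypothetical transformer: DataFrame in, plain numeric array out."""

    def fit(self, X, y=None):
        # Remember the categories seen during fit so the columns stay stable.
        self.letters_ = sorted(X["Letter"].unique())
        return self

    def transform(self, X):
        number = (X["Number"] * 2).to_numpy(dtype=float).reshape(-1, 1)
        letters = X["Letter"].to_numpy()
        onehot = np.column_stack(
            [(letters == letter).astype(float) for letter in self.letters_]
        )
        return np.hstack([number, onehot])


X = pd.DataFrame({"Letter": list("abcd") * 4, "Number": [1, 2, 3, 4] * 4})
y = pd.Series(["G", "G", "B", "B"] * 4)

pipe = Pipeline([
    ("to_array", FrameToArray()),
    ("classifier", RandomForestClassifier(random_state=0)),
])
param_grid = {"classifier__n_estimators": [2, 3, 4]}

search = GridSearchCV(pipe, param_grid, cv=2, scoring="accuracy")
search.fit(X, y)

features = FrameToArray().fit(X).transform(X)
```

The trade-off is the one described in option 2: once a step returns an array, no later step in the pipeline can rely on column names.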