在自定义类上使用sklearn GridSearchCV,其fit方法需要3个参数

时间:2017-08-30 17:31:13

标签: python numpy machine-learning scikit-learn grid-search

我正在开发一个涉及将一些算法实现为python类并测试其性能的项目。我决定将它们写成sklearn估算器,以便我可以使用GridSearchCV进行验证。

但是,我的Inductive Matrix Completion算法之一不仅仅需要Xy作为参数。这成为GridSearchCV.fit的问题,因为似乎无法将Xy传递给估算器的拟合方法。源显示GridSearchCV.fit的以下参数:

def fit(self, X, y=None, groups=None, **fit_params):

当然,下游方法只需要这两个参数。显然,修改GridSearchCV的本地副本以满足我的需求并非易事(或可取)。

作为参考,IMC基本上表示$ R \约XW ^ THY ^ T $。所以我的拟合方法采用以下形式:

def fit(self, R, X, Y):

因此,尝试以下操作失败,因为Y值永远不会传递给IMC.fit方法:

imc = IMC()
params = {...}
gs = GridSearchCV(imc, param_grid=params)
gs.fit(R, X, Y)

我已经通过修改IMC.fit方法为此创建了一个解决方法(这也必须插入到score方法中):

def fit(self, R, X, Y=None):
    if Y is None:
        split = np.where(np.all(X == 999, axis=0))[0][0]
        Y = X[:, split + 1:]
        X = X[:, :split]
    ...

这允许我使用numpy.hstack水平堆叠X和Y,并在它们之间插入所有999的列。然后可以将此数组传递给GridSearchCV.fit,如下所示:

data = np.hstack([X, np.ones((X.shape[0],1)) * 999, Y])
gs.fit(R, data)

这种方法有效,但感觉相当hacky。因此,我的问题是:

是否有一种普遍接受的方式或最佳做法,使用GridSearchCV将超过2个参数传递给fit方法?

1 个答案:

答案 0 :(得分:1)

所以在从朋友那里得到一些灵感(@Matthew Drury)后,我构建了一个更优雅的解决方案。

问题再次出现:

我有一个矩阵完成方法,它将XYR作为参数,并尝试构建WH来最小化{{ 1}}用于R - XWHY中所有观察到的索引。 R方法的基本实现如下所示:

fit

这不适合标准的sklearn模型,其中拟合需要def fit(X, Y, R): W, H = do_minimization(X, Y, R) return W, H (提供给模型的功能)和X(结果),如下所示:

y

在您开始使用def fit(X, y): W, H = do_minimization(X, y) return W, H 或其他交叉验证方法之前,这不是真正的问题,因为他们希望数据符合后一种格式。因此,为了将这两个概念结合起来,我需要一种方法将两个不同的矩阵GridSearchCVX打包成一个结构,而不会失去两者的独立性。

在5分钟内,我不得不专心致志于此,我想出了hacky解决方案。在矩阵Y形状R中,其中行对应于n, m中的记录,而列对应于X中的记录,总共有Y个条目。如果我们对行中的所有条目和索引b以及列上的X获取行和列索引,我们将得到Y和{{1}的等长矩阵}}。然后可以将它们水平堆叠,由一列废话分隔,并传递给交叉验证方法而不会出现问题(我们只需要在原始类中使用几个辅助方法来重建原始XY在装配之前从堆栈中。

这个问题的关键是找到优雅的解决方案,或者最好是现有的解决方案。情况似乎并非如此,因此我将为未来从sklearn继承的估算器/分类器提出以下模型,该估算器/分类器需要的不仅仅是fit方法的单个特征矩阵。

创建DataHandler

使用X时,Y方法会在启动对估算工具GridSearchCV方法的任何调用之前执行一轮检查。其中一个确定传递的fit数组是否为indexable。此测试基本上检查fit是否实现XX,并且与__getitem__的长度相同。此长度检查要求iloc具有y属性。此时,可以按预期计算拆分索引和拟合。因此,我们需要一个实现X且具有shape属性的包装器。

__getitem__

多数民众赞成!我们现在可以修改shape方法以匹配sklearn样式,但在这种情况下,而不是class DataHandler(object): def __init(self, X, Y): self.X = X self.Y = Y self.shape = self.X.shape def __getitem__(self, x): return self.X[x], self.Y[x] 是一个数组,它将是一个元组(fit方法返回的结果)或我们X类的实例。

现在__getitem__只需传递包含DataHandlerGridSearchCV数组的DataHandler实例即可按预期工作。