我正在开发一个涉及将一些算法实现为python类并测试其性能的项目。我决定将它们写成sklearn估算器,以便我可以使用GridSearchCV
进行验证。
但是,我的Inductive Matrix Completion算法之一不仅仅需要X
和y
作为参数。这成为GridSearchCV.fit
的问题,因为似乎无法将X
和y
传递给估算器的拟合方法。源显示GridSearchCV.fit
的以下参数:
def fit(self, X, y=None, groups=None, **fit_params):
当然,下游方法只需要这两个参数。显然,修改GridSearchCV
的本地副本以满足我的需求并非易事(或可取)。
作为参考,IMC基本上表示$ R \约XW ^ THY ^ T $。所以我的拟合方法采用以下形式:
def fit(self, R, X, Y):
因此,尝试以下操作失败,因为Y值永远不会传递给IMC.fit
方法:
imc = IMC()
params = {...}
gs = GridSearchCV(imc, param_grid=params)
gs.fit(R, X, Y)
我已经通过修改IMC.fit
方法为此创建了一个解决方法(这也必须插入到score
方法中):
def fit(self, R, X, Y=None):
if Y is None:
split = np.where(np.all(X == 999, axis=0))[0][0]
Y = X[:, split + 1:]
X = X[:, :split]
...
这允许我使用numpy.hstack
水平堆叠X和Y,并在它们之间插入所有999
的列。然后可以将此数组传递给GridSearchCV.fit
,如下所示:
data = np.hstack([X, np.ones((X.shape[0],1)) * 999, Y])
gs.fit(R, data)
这种方法有效,但感觉相当hacky。因此,我的问题是:
GridSearchCV
将超过2个参数传递给fit方法?答案 0 :(得分:1)
所以在从朋友那里得到一些灵感(@Matthew Drury)后,我构建了一个更优雅的解决方案。
问题再次出现:
我有一个矩阵完成方法,它将X
,Y
和R
作为参数,并尝试构建W
和H
来最小化{{ 1}}用于R - XWHY
中所有观察到的索引。 R
方法的基本实现如下所示:
fit
这不适合标准的sklearn模型,其中拟合需要def fit(X, Y, R):
W, H = do_minimization(X, Y, R)
return W, H
(提供给模型的功能)和X
(结果),如下所示:
y
在您开始使用def fit(X, y):
W, H = do_minimization(X, y)
return W, H
或其他交叉验证方法之前,这不是真正的问题,因为他们希望数据符合后一种格式。因此,为了将这两个概念结合起来,我需要一种方法将两个不同的矩阵GridSearchCV
和X
打包成一个结构,而不会失去两者的独立性。
在5分钟内,我不得不专心致志于此,我想出了hacky解决方案。在矩阵Y
形状R
中,其中行对应于n, m
中的记录,而列对应于X
中的记录,总共有Y
个条目。如果我们对行中的所有条目和索引b
以及列上的X
获取行和列索引,我们将得到Y
和{{1}的等长矩阵}}。然后可以将它们水平堆叠,由一列废话分隔,并传递给交叉验证方法而不会出现问题(我们只需要在原始类中使用几个辅助方法来重建原始X
和Y
在装配之前从堆栈中。
这个问题的关键是找到优雅的解决方案,或者最好是现有的解决方案。情况似乎并非如此,因此我将为未来从sklearn继承的估算器/分类器提出以下模型,该估算器/分类器需要的不仅仅是fit方法的单个特征矩阵。
使用X
时,Y
方法会在启动对估算工具GridSearchCV
方法的任何调用之前执行一轮检查。其中一个确定传递的fit
数组是否为indexable。此测试基本上检查fit
是否实现X
或X
,并且与__getitem__
的长度相同。此长度检查要求iloc
具有y
属性。此时,可以按预期计算拆分索引和拟合。因此,我们需要一个实现X
且具有shape
属性的包装器。
__getitem__
多数民众赞成!我们现在可以修改shape
方法以匹配sklearn样式,但在这种情况下,而不是class DataHandler(object):
def __init(self, X, Y):
self.X = X
self.Y = Y
self.shape = self.X.shape
def __getitem__(self, x):
return self.X[x], self.Y[x]
是一个数组,它将是一个元组(fit
方法返回的结果)或我们X
类的实例。
现在__getitem__
只需传递包含DataHandler
和GridSearchCV
数组的DataHandler
实例即可按预期工作。