如何仅覆盖sklearn分类器的fit()函数?

时间:2019-11-26 16:46:16

标签: python inheritance scikit-learn

我想在sklearn的库中将SGDClassifier()与gridsearchCV一起使用scoring ='precision'对其分类为1还是0。当我使用kfold 5拆分数据时,有时会遇到某些y_train拆分仅具有全0或全1的情况。 SGDClassifer会吐出一个错误,说我所有的标签只有1个类。我的gridsearchCV()完全停止了。如何正确处理这种情况?我希望SGDClassifer返回精度为0并继续进行gridsearch。我试图编写一个继承SGDClassifer的SGDClassifer2()类,但不确定是否正确执行了此操作。这是我尝试覆盖SGDClassifer的适合度:

class SGD2(SGDClassifier):

    #
    # only override the fit() function here
    #
    def fit(self, X, y):
        # if array y are all 0s or all 1s
        if((sum(y) == 0) || sum(y) == len(y)):
           return 0
        else:
           return self.fit(X,y)

对不起,我是初学者,感谢您的任何建议或帮助。

0 个答案:

没有答案