覆盖在sklearn上下文中使用的statsmodels GLM中的predict()

时间:2017-11-21 15:28:13

标签: python scikit-learn override glm statsmodels

在sklearn的上下文中使用statsmodel的Poisson GLM模型,我正在尝试建立一个自己的模型,它继承自GLM,BaseEstimator和RegressorMixin。我的目标是做交叉验证之类的东西。这是我的代码:

import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin

class GLM_sklearn(sm.GLM, BaseEstimator, RegressorMixin):
    def __init__(self, X, y, family=sm.families.Poisson()):
        super().__init__(y, X, family=family)

    def fit(self, **kwargs):
        self.results_ = super().fit()

        self.coef_ = self.results_.params.values
        self.bse_ = self.results_.bse.values

        return self

    def predict(self, X, **kwargs):
        return self.results_.predict(X)

fit方法工作正常,但我有覆盖predict()的问题。预测我需要结果实例的预测方法(GLMResultsWrapper)。所以我想覆盖GLM.predict方法(它有另一个功能)。正如在代码中尝试的那样,我得到了预期的错误:

predict_results = self.model.predict(self.params,exog,* args,** kwargs) TypeError:predict()需要2个位置参数,但是3个被赋予

是否有可能“完全”覆盖预测方法?

1 个答案:

答案 0 :(得分:1)

您可能希望GLM_sklearn拥有sm.GLM和RegressorMixin的实例并且只从BaseEstimator继承,而不是继承所有三个可能会导致问题,例如一个父类覆盖其他成员。然后,您可以实现拟合和预测,而不必担心父类的成员。