sklearn模型是什么类型?

时间:2019-02-25 14:46:12

标签: python scikit-learn

我正在编写一些代码,针对某些数据评估不同的sklearn模型。我正在使用类型提示,既用于我自己的学习,又用于帮助最终将不得不阅读我的代码的其他人。

我的问题是如何指定sklearn预测变量的类型(例如LinearRegression())?

例如:

def model_tester(model : Predictor,
                 parameter: int
                 ) -> np.ndarray:
     """An example function with type hints."""

     # do stuff to model 

     return values

我看到typing library可以创建新的类型,也可以使用TypeVar进行操作:

Predictor = TypeVar('Predictor') 

但如果已有sklearn模型的常规类型,我就不想使用它。

检查LinearRegression()的类型会产生:

 sklearn.linear_model.base.LinearRegression

这显然是有用的,但是仅当我对LinearRegression模型感兴趣时。

3 个答案:

答案 0 :(得分:7)

从Python 3.8开始(或更早使用typing-extensions),您可以使用typing.Protocol。使用协议,您可以使用名为structural subtyping的概念来精确定义类型的预期结构:

from typing import Protocol
# from typing_extensions import Protocol  # for Python <3.8

class ScikitModel(Protocol):
    def fit(self, X, y, sample_weight=None): ...
    def predict(self, X): ...
    def score(self, X, y, sample_weight=None): ...
    def set_params(self, **params): ...

然后您可以将其用作类型提示:

def do_stuff(model: ScikitModel) -> Any:
    model.fit(train_data, train_labels)  # this type checks 
    score = model.score(test_data, test_labels)  # this type checks
    ...

答案 1 :(得分:5)

一个好的解决方法是创建自己的自定义类型提示类(使用Union),该类包括您常用的所有模型。它需要更多的努力,但可以使您变得具体并且可以与PyCharm一起使用。

ModelRegressor = Union[LinearRegression, DecisionTreeRegressor, RandomForestRegressor, SVR]

def foo(model: ModelRegressor):
    do_something

答案 2 :(得分:3)

我认为所有模型继承的最通用的类​​是sklearn.base.BaseEstimator

如果您想更具体一点,可以使用sklearn.base.ClassifierMixinsklearn.base.RegressorMixin

所以我会这样做:

from sklearn.base import RegressorMixin


def model_tester(model: RegressorMixin, parameter: int) -> np.ndarray:
     """An example function with type hints."""

     # do stuff to model 

     return values

我不是类型检查方面的专家,如果这不正确,请纠正我。