我正在编写一些代码,针对某些数据评估不同的sklearn模型。我正在使用类型提示,既用于我自己的学习,又用于帮助最终将不得不阅读我的代码的其他人。
我的问题是如何指定sklearn预测变量的类型(例如LinearRegression()
)?
例如:
def model_tester(model : Predictor,
parameter: int
) -> np.ndarray:
"""An example function with type hints."""
# do stuff to model
return values
我看到typing library可以创建新的类型,也可以使用TypeVar
进行操作:
Predictor = TypeVar('Predictor')
但如果已有sklearn模型的常规类型,我就不想使用它。
检查LinearRegression()的类型会产生:
sklearn.linear_model.base.LinearRegression
这显然是有用的,但是仅当我对LinearRegression模型感兴趣时。
答案 0 :(得分:7)
从Python 3.8开始(或更早使用typing-extensions),您可以使用typing.Protocol
。使用协议,您可以使用名为structural subtyping的概念来精确定义类型的预期结构:
from typing import Protocol
# from typing_extensions import Protocol # for Python <3.8
class ScikitModel(Protocol):
def fit(self, X, y, sample_weight=None): ...
def predict(self, X): ...
def score(self, X, y, sample_weight=None): ...
def set_params(self, **params): ...
然后您可以将其用作类型提示:
def do_stuff(model: ScikitModel) -> Any:
model.fit(train_data, train_labels) # this type checks
score = model.score(test_data, test_labels) # this type checks
...
答案 1 :(得分:5)
一个好的解决方法是创建自己的自定义类型提示类(使用Union),该类包括您常用的所有模型。它需要更多的努力,但可以使您变得具体并且可以与PyCharm一起使用。
ModelRegressor = Union[LinearRegression, DecisionTreeRegressor, RandomForestRegressor, SVR]
def foo(model: ModelRegressor):
do_something
答案 2 :(得分:3)
我认为所有模型继承的最通用的类是sklearn.base.BaseEstimator
。
如果您想更具体一点,可以使用sklearn.base.ClassifierMixin
或sklearn.base.RegressorMixin
。
所以我会这样做:
from sklearn.base import RegressorMixin
def model_tester(model: RegressorMixin, parameter: int) -> np.ndarray:
"""An example function with type hints."""
# do stuff to model
return values
我不是类型检查方面的专家,如果这不正确,请纠正我。