使 ML 模型与 scikit-learn 兼容

时间:2021-05-09 00:03:35

标签: python scikit-learn

我想让这个 ML 模型与 scikit-learn 兼容: https://github.com/manifoldai/merf

为此,我按照此处的说明进行操作:https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/ 并导入 import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware # Import A module from my own project that has the routes defined from redorg.routers import saved_items origins = [ 'http://localhost:8080', ] webapp = FastAPI() webapp.include_router(saved_items.router) webapp.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=['*'], allow_headers=['*'], ) def serve(): """Serve the web application.""" uvicorn.run(webapp) if __name__ == "__main__": serve() 并从它们继承而来,如下所示: from sklearn.base import BaseEstimator, RegressorMixin

但是,当我检查 scikit-learn 兼容性时:

class MERF(BaseEstimator, RegressorMixin):

我收到此错误:

from sklearn.utils.estimator_checks import check_estimator

import merf
check_estimator(merf)

如何使这个模型与 scikit-learn 兼容?

1 个答案:

答案 0 :(得分:6)

docs 中,check_estimator 用于“检查估算器是否符合 scikit-learn 约定。”

<块引用>

此估算器将运行广泛的测试套件以进行输入验证、形状等,确保估算器符合滚动您自己的估算器中详述的 scikit-learn 约定。如果 Estimator 类继承自 sklearn.base 的相应 mixin,则将运行针对分类器、回归器、聚类或转换器的其他测试。

所以 check_estimator 不仅仅是兼容性检查,它还检查您是否遵循所有约定等。

您可以阅读 rolling your own estimator 以确保遵守约定。

然后你需要传递你的 estimator 类的一个实例来检查像 check_estimator(MERF()) 这样的 esimator。要真正使其遵循所有约定,您必须解决它抛出的每个错误并一一修复。

例如,其中一项检查是 __init__ 方法只设置它作为参数接受的那些属性。

MERF 类违反了:

    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

它正在设置诸如 self.b_hat_history 之类的属性,即使它们不是参数。

还有很多其他类似的检查。

我个人的建议是,除非必要,否则不要检查所有这些条件,只需继承 Mixins 和 Base 类,实现所需的方法并使用模型。

相关问题