检查sklearn模型是否已安装的最佳方法是什么?即,是否在实例化之后调用了它的fit()
函数。
答案 0 :(得分:9)
您可以执行以下操作:
from sklearn.exceptions import NotFittedError
for model in models:
try:
model.predict(some_test_data)
except NotFittedError as e:
print(repr(e))
理想情况下,您会根据预期结果检查model.predict
的结果,但是如果您想知道是否所有模型都适合,那么就足够了。
一些评论者建议使用check_is_fitted。我认为check_is_fitted
是internal method。大多数算法会在预测方法中调用check_is_fitted
,如果需要,可能会提升NotFittedError
。直接使用check_is_fitted
的问题在于它是特定于模型的,即您需要知道要根据您的算法检查哪些成员。例如:
╔════════════════╦════════════════════════════════════════════╗
║ Tree models ║ check_is_fitted(self, 'tree_') ║
║ Linear models ║ check_is_fitted(self, 'coefs_') ║
║ KMeans ║ check_is_fitted(self, 'cluster_centers_') ║
║ SVM ║ check_is_fitted(self, 'support_') ║
╚════════════════╩════════════════════════════════════════════╝
等等。所以一般情况下我会建议调用model.predict()
并让特定算法处理检查它是否已经安装的最佳方法。
答案 1 :(得分:4)
我这样做是为了分类器:
def check_fitted(clf):
return hasattr(clf, "classes_")
答案 2 :(得分:2)
这是一种贪婪的方法,但对大多数(如果不是所有型号)都应该没问题。这可能不起作用的唯一时间是模型在适合之前设置以下划线结尾的属性,我很确定会违反scikit-learn惯例,所以这应该没问题。
import inspect
def is_fitted(model):
"""Checks if model object has any attributes ending with an underscore"""
return 0 < len( [k for k,v in inspect.getmembers(model) if k.endswith('_') and not k.startswith('__')] )
答案 3 :(得分:0)
直接从scikit学习源代码中获取check_is_fitted
函数(与@ david-marx相似,但是更简单):
def is_fitted(model):
'''
Checks if a scikit-learn estimator/transformer has already been fit.
Parameters
----------
model: scikit-learn estimator (e.g. RandomForestClassifier)
or transformer (e.g. MinMaxScaler) object
Returns
-------
Boolean that indicates if ``model`` has already been fit (True) or not (False).
'''
attrs = [v for v in vars(model)
if v.endswith("_") and not v.startswith("__")]
return len(attrs) != 0