我正在玩一些scikit学习对象,并且在尝试调整超参数时偶然发现以下结果。这部分代码的输出
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
class NaiveBayesClassifier(Pipeline):
def __init__(self):
super().__init__(
[("tfidf", TfidfVectorizer()), ("clf", MultinomialNB()),]
)
def tune(self, data, min_df, max_df, max_features):
gs = GridSearchCV(
estimator=self,
param_grid={
"tfidf__max_df": max_df,
"tfidf__min_df": min_df,
"tfidf__max_features": max_features,
},
verbose=10,
)
return gs.fit(*data)
if __name__ == "__main__":
from sklearn.datasets import fetch_20newsgroups
categories = ["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"]
twenty_train = fetch_20newsgroups(
subset="train", categories=categories, shuffle=True, random_state=42
)
nb = NaiveBayesClassifier()
tuned_model = nb.tune(
(twenty_train.data, twenty_train.target),
min_df=[0, 0.1],
max_df=[0.9, 1],
max_features=[2_000, 5_000],
)
print(tuned_model.best_score_)
for k, v in tuned_model.best_params_.items():
print(f"{v} <> {tuned_model.best_estimator_.get_params()[k]}")
是以下
0.9437278025233994
0.9 <> 1.0
5000 <> None
0 <> 1
看看网格搜索生成的输出,我可以看到左边的参数确实在5折上产生了平均得分。因此,看来tuned_model.best_params_
是我所期望的。但是,参数
best_estimator_
只是默认设置。
这是什么原因? Pipeline
类具有一个set_params
方法,该方法似乎对tuned_model.best_estimator_.set_params(tuned_model.best_params_)
做了正确的事情(但现在当然只有参数是最优的,而模型不是)。
答案 0 :(得分:0)
子类似乎不适用于sklearn.base.clone
,因为在子类上调用get_params
的结果与在{{1的实际实例上调用的相同方法的结果不同}}(从0.22版开始)。
为了进行适当的sklearn估计,Pipeline
方法必须将参数声明为显式关键字参数。覆盖__init__
有效地创建了没有参数的估计量。这说明了为什么def __init__(self)
返回空的get_params(deep=False)
。
以下子类有效
dict
这现在可以按预期工作,但是class NaiveBayesClassifier(Pipeline):
def __init__(self, steps=[], memory=None, verbose=False):
super().__init__(
steps or [("tfidf", TfidfVectorizer()), ("clf", MultinomialNB())]
)
def tune(self, data, min_df, max_df, max_features):
gs = GridSearchCV(
estimator=self,
param_grid={
"tfidf__max_df": max_df,
"tfidf__min_df": min_df,
"tfidf__max_features": max_features,
},
verbose=10,
)
return gs.fit(*data)
的签名是“难看的”,因为原则上它不应该接受任何参数。