在SkLearn中打印估算器名称

时间:2019-01-05 06:30:44

标签: python model scikit-learn

在Sklearn中,是否可以打印出估算器的类名?

我尝试使用 name 属性,但这不起作用。

from sklearn.linear_model import LogisticRegression  

def print_estimator_name(estimator):
    print(estimator.__name__)

#Expected Outcome:
print_estimator_name(LogisticRegression())

我希望这会打印出如上所述的分类器名称

2 个答案:

答案 0 :(得分:2)

我认为您正在寻找estimator.__class__.__name__,即:

from sklearn.linear_model import LogisticRegression

def print_estimator_name(estimator):
    print(estimator.__class__.__name__)

#Expected Outcome:
print_estimator_name(LogisticRegression())

答案 1 :(得分:2)

我有另一种方法。获取对象名称,转换为str,使用split(".")获取最重要的子类,最后剥离不想要的字符

str(type(clf)).split(".")[-1][:-2])

这对我来说适用于SKLearn,XGBoost和LightGBM

print("Acc: %0.5f for the %s" % (pred, str(type(clf)).split(".")[-1][:-2]))
Acc: 0.7159443 : DecisionTreeClassifier
Acc: 0.7572368 : RandomForestClassifier
Acc: 0.7548593 : ExtraTreesClassifier
Acc: 0.7416970 : XGBClassifier
Acc: 0.7582540 : LGBMClassifier