我正在学习sklearn
,我写了一个课程Classifier
来做共同的分类。它需要method
来确定使用哪个Estimator:
# Classifier
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
class Classifier(object):
def __init__(self, method='LinearSVC', *args, **kwargs):
Estimator = getattr(**xxx**, method, None)
self.Estimator = Estimator
self._model = Estimator(*args, **kwargs)
def fit(self, data, target):
return self._model.fit(data, target)
def predict(self, data):
return self._model.predict(data)
def score(self, X, y, sample_weight=None):
return self._model.score(X, y, sample_weight=None)
def persist_model(self):
pass
def get_model(self):
return self._model
def classification_report(self, expected, predicted):
return metrics.classification_report(expected, predicted)
def confusion_matrix(self, expected, predicted):
return metrics.confusion_matrix(expected, predicted)
我想通过名字获得Estimator,但 xxx 应该是什么?
或者有更好的方法吗?
构建一个dict来存储导入的模块?但这种方式似乎不太好..
答案 0 :(得分:1)
在这种情况下,建议直接使用该类作为参数。
您永远不必担心它是一个字符串:您可以比较LinearSVC is LinearSVC
,并将其与其他内容进行比较。
将其视为接受一个整数作为参数,然后将其转换为字符串以使用它:这有意义吗?你可以只需要一个字符串。
建议代码:
class Classifier(object):
def __init__(self, model = LinearSVC, *args, **kwargs):
self._model = model(*args, **kwargs)
然后你可以这样做:
myclf = Classifier(..., estimator = LinearSVC, ...)
isinstance(myclf._model, LinearSVC)
然后你也可以在开始时初始化一个dict:
from sklearn.svm import LinearSVC
str_to_model = {'LinearSVC' : LinearSVC}
class Classifier(object):
def __init__(self, model = "LinearSVC", *args, **kwargs):
self._model = str_to_model[model](*args, **kwargs)
使用KeyError
更加清晰(字符串/模型不存在,并且您已经知道,因为您没有定义它们),而不是检查globals
,听起来很讨厌!
答案 1 :(得分:0)
内置函数globals()可以解决问题:您可以检查True
是否返回globals()[method]
。
附录
some_method_dict[method]
globals()[method]
globals()
只是问题的最短答案。如果这是pythonic或不是pythonic,我不知道,但Estimator = getattr(..., method, None)
内置是否可以使用,那么为什么选择更复杂的解决方案呢?
要明确,
Estimator = globals().get(method)
可以实现为
None
如果未导入KeyError
,则method
返回首选@POST("/my/url/path")
Result postToServer(
@Query("user_name") String userName);
例外。
答案 2 :(得分:0)
有两个内置函数可以帮助您:globals
和locals
,两者都返回当前符号表的dict。
您的代码可能是Estimator = globals()[method]
或mv到[{1}}的估算键并使用__init__