按名称获取当前文件中已导入的模块

时间:2015-06-09 07:58:15

标签: python python-import

我正在学习sklearn,我写了一个课程Classifier来做共同的分类。它需要method来确定使用哪个Estimator:

# Classifier
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

class Classifier(object):
    def __init__(self, method='LinearSVC', *args, **kwargs):
        Estimator = getattr(**xxx**, method, None)
        self.Estimator = Estimator
        self._model = Estimator(*args, **kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)

    def score(self, X, y, sample_weight=None):
        return self._model.score(X, y, sample_weight=None)

    def persist_model(self):
        pass

    def get_model(self):
        return self._model

    def classification_report(self, expected, predicted):
        return metrics.classification_report(expected, predicted)

    def confusion_matrix(self, expected, predicted):
        return metrics.confusion_matrix(expected, predicted)

我想通过名字获得Estimator,但 xxx 应该是什么? 或者有更好的方法吗?
构建一个dict来存储导入的模块?但这种方式似乎不太好..

3 个答案:

答案 0 :(得分:1)

在这种情况下,建议直接使用该类作为参数。

您永远不必担心它是一个字符串:您可以比较LinearSVC is LinearSVC,并将其与其他内容进行比较。

将其视为接受一个整数作为参数,然后将其转换为字符串以使用它:这有意义吗?你可以只需要一个字符串。

建议代码:

class Classifier(object):
    def __init__(self, model = LinearSVC, *args, **kwargs):
        self._model = model(*args, **kwargs)

然后你可以这样做:

myclf = Classifier(..., estimator = LinearSVC, ...)
isinstance(myclf._model, LinearSVC)

根据评论:

然后你也可以在开始时初始化一个dict:

from sklearn.svm import LinearSVC

str_to_model = {'LinearSVC' : LinearSVC}

class Classifier(object):
    def __init__(self, model = "LinearSVC", *args, **kwargs):
        self._model = str_to_model[model](*args, **kwargs)

使用KeyError更加清晰(字符串/模型不存在,并且您已经知道,因为您没有定义它们),而不是检查globals ,听起来很讨厌!

答案 1 :(得分:0)

内置函数globals()可以解决问题:您可以检查True是否返回globals()[method]

附录

  • 安全:没有什么'讨厌'可以通过评估some_method_dict[method]
  • 来实现
  • 效率:相对于globals()[method]
  • ,开销可以忽略不计
  • 简单:globals()只是问题的最短答案。

如果这是pythonic或不是pythonic,我不知道,但Estimator = getattr(..., method, None) 内置是否可以使用,那么为什么选择更复杂的解决方案呢?

要明确,

Estimator = globals().get(method)

可以实现为

None

如果未导入KeyError,则method返回首选@POST("/my/url/path") Result postToServer( @Query("user_name") String userName); 例外。

答案 2 :(得分:0)

有两个内置函数可以帮助您:globalslocals,两者都返回当前符号表的dict。

您的代码可能是Estimator = globals()[method]或mv到[{1}}的估算键并使用__init__