如何在Python

时间:2015-04-22 15:34:46

标签: python scikit-learn classification text-classification

首先,手头的问题。我正在为scikit-learn类编写一个包装器,并且遇到了正确语法的问题。我想要实现的是覆盖fit_transform函数,它稍微改变输入,然后使用新参数调用其super - 方法:

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):
    def __init__(self):
        TfidfVectorizer.__init__(self)  # is this even necessary?

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return TfidfVectorizer.fit_transform(self, x, y, fit_params)  
                            # this is the critical part, my IDE tells me for
                            # fit_params: 'unexpected arguments'

程序在整个地方崩溃,从Multiprocessing exception开始,并没有真正告诉我任何有用的东西。我该如何正确地做到这一点?

其他信息:我需要以这种方式包装它的原因是因为我使用sklearn.pipeline.FeatureUnion来收集我的特征提取器,然后将它们放入sklearn.pipeline.Pipeline。这样做的结果是,我只能在所有特征提取器中提供单个数据集 - 但不同的提取器需要不同的数据。我的解决方案是以一种易于分离的格式提供数据,并在不同的提取器中过滤不同的部分。如果能够更好地解决这个问题,我也很乐意听到它。

编辑1: 添加**来解压缩dict似乎没有改变任何东西: Screenshot

编辑2: 我刚刚解决了剩下的问题 - 我需要删除构造函数重载。显然,通过尝试调用父构造函数,希望正确启动所有实例变量,我做了完全相反的事情。我的包装器不知道它可以期待什么样的参数。一旦我删除了多余的电话,一切都完美无缺。

1 个答案:

答案 0 :(得分:3)

您忘记解压缩fit_params作为dict传递,并希望将其作为keyword arguments传递,需要解压缩运算符**

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return TfidfVectorizer.fit_transform(self, x, y, **fit_params)  

直接调用TfidfVectorizer' fit_transform的另一件事是你可以通过super方法调用重载版本

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return super(TidfVectorizerWrapper, self).fit_transform(x, y, **fit_params)  

要理解它,请查看以下示例

def foo1(**kargs):
    print kargs

def foo2(**kargs):
    foo1(**kargs)
    print 'foo2'

def foo3(**kargs):
    foo1(kargs)
    print 'foo3'

foo1(a=1, b=2)

打印字典{'a': 1, 'b': 2}

foo2(a=1, b=2)

打印字典和foo2,但

foo3(a=1, b=2)

引发错误,因为我们将{strong>位置参数等同于我们的字典发送到foo1,它不接受这样的事情。但是我们可以做到

def foo4(**kargs):
    foo1(x=kargs)
    print 'foo4'

工作正常,但会打印一个新字典{'x': {'a': 1, 'b': 2}}