sklearn管道中是否存在数据类型约束?

时间:2019-10-31 14:53:06

标签: python scikit-learn pipeline

我正在使用扫描的文档作为输入来创建文档分类器。我已经有一个训练有素的模型,但是我目前正在重构代码并构建独立的块,以便更好地维护和更轻松地理解代码。我对Python相对较新,因此我可能无法正确使用sklearn-Pipelines,因为我将图像路径用作输入而不是DataFrames,但是我读到的唯一要求是使用fit和transform方法构建自定义的转换器,所以我正在尝试现在走。

我想创建一个像sklearn一样的管道。我按照《使用Scikit-Learn和Tensor-flow进行动手机器学习》一书中解释的结构创建了不同的块。我的第一个块是图像加载器(输入:图像路径,输出:PIL.Image类型),图像处理块,光学字符识别块,文本转换器块和分类器将跟随该图像加载器。使管道正常工作的第一个尝试非常简单:我试图使其仅与第一个块一起工作,但是我已经偶然发现了一个主要问题。

我粘贴ImageLoader类以及如何创建管道:

class ImageLoader(BaseEstimator, TransformerMixin):
'''
this class takes a path as input an return a PIL image
'''

def __init__(self, dpi_max=400, dpi_min=200):
    self.dpi_max = dpi_max
    self.dpi_min = dpi_min
    self.path = None

def fit(self, X, y=None):
    assert os.path.isfile(X), 'e: file does not exist'
    self.path = X

    return self

def transform(self, X, y = None):
    print("processing {}...".format(self.path))
    try:
        if '.pdf' in self.path:
            try:
                with tempfile.TemporaryDirectory() as path:
                    image = convert_from_path(self.path, dpi=self.dpi_max, output_folder=path)[0]

            except DecompressionBombError:
                print('using min_dpi')
                with tempfile.TemporaryDirectory() as path:
                    image = convert_from_path(self.path, dpi=self.dpi_min, output_folder=path)[0]
        else:
            image = Image.open(self.path)

    except Exception as e:
        print('ERROR: ', str(e))
        return str(e)

    if image.height > MAX_INT16 or image.width > MAX_INT16:
        print('warning: this image is too big for tesseract algorithm')

    return image

def fit_transform(self, X, y = None):
    self.fit(X)
    return self.transform()


class TestPipelineDocumentClassifier:

    def test_try_pipeline(self):
        image = cv2.imread('tests/data/test_Schwerbehindertenausweis.jpg')
        X = 'tests/data/test_Schwerbehindertenausweis.jpg'

        pipeline = Pipeline([('image_loader', ImageLoader)])

        image = pipeline.fit_transform(X)

运行测试“ test_try_pipeline”时,出现以下错误消息:

def fit_transform(self, X, y = None):
 self.fit(X)
AttributeError: 'str' object has no attribute 'fit'

当解释器输入我的ImageLoader的fit_transform(self,X,y = None)时,问题就出现了,参数“ self”包含路径而不是对对象本身的引用。这个怎么可能?我进行了调试,发现在sklearn文件pipeline.py,第393行中:

return last_step.fit_transform(Xt, y, **fit_params)

Xt包含图像的路径,并且'last_step'是ImageLoader类型的对象,因此,这里的一切看起来都很好,但是如果我进入该方法,那么我会发现自己处于ImageLoader :: fit_transform(self,X,y = 0),其中参数“ self”包含路径,“ X”为空。为什么会这样呢?为什么“自我”不是对ImageLoader对象的引用,而“ X”包含路径?

由于某种原因,数据类型可能会成为问题吗?如果是这样,是否有一种方法可以创建类似的管道,其中块的输出是以下内容的输入?比做起来更优雅:

path = 'path_to_image'
ImageProcessor().fit_transform(ImageLoader().fit_transform(path))

非常感谢您的帮助:)

0 个答案:

没有答案