Passing a SparkSession to a custom Transformer at instantiation time

Date: 2019-05-17 17:12:40

Tags: python pyspark

I am writing my own transformers for my PySpark project and I've run into a problem:

If I write a transformer directly in the module/notebook where it will be used, everything works fine. For example:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import (HasInputCol, HasInputCols, HasOutputCol, 
    HasOutputCols, Param)
from pyspark.sql import (SparkSession, types, functions as funcs)

spark = SparkSession.builder.appName('my_session').getOrCreate()

# My Custom Transformer 1:
class MyTransformerOne(Transformer, HasInputCol, HasOutputCol):
    @keyword_only
    def __init__(self, inputCol='my_input', outputCol='my_output'):
        super(MyTransformerOne, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol='my_input', outputCol='my_output'):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        # I need a little dataframe here to perform some tasks:
        df = spark.createDataFrame(
            [
                {'col1': 1, 'col2': 'A'}, {'col1': 2, 'col2': 'B'}
            ],
            schema = types.StructType([
                types.StructField('col1', types.IntegerType(), True),
                types.StructField('col2', types.StringType(), True),
            ])
        )
        pass # Lots of other things happen here... the little dataframe above
             # is joined with the 'to be transformed' dataset and some columns
             # are calculated.
        return final_dataset

df = MyTransformerOne().transform(input_df)
# This works OK

I have seven transformers like this, so I want to store them in a separate module (let's call it my_transformers.py). I thought: "Well, I need a SparkSession object for this to work... so let's take it in the __init__ method." But it doesn't work:

"""
my_transformers.py
"""

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import (HasInputCol, HasInputCols, HasOutputCol, 
    HasOutputCols, Param)
from pyspark.sql import (types, functions as funcs)

class MyTransformerOne(Transformer, HasInputCol, HasOutputCol):
    @keyword_only
    def __init__(self, spark=None, inputCol='my_input', outputCol='my_output'):
        super(MyTransformerOne, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol='my_input', outputCol='my_output'):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        # Let's use the instance attribute to create the dataframe
        df = self.spark.createDataFrame(...)
        # ... same as above

Then, in another module/notebook:

import my_transformers

# ... Create a spark session, load the data, etcetera
df = my_transformers.MyTransformerOne().fit(input_df).transform(input_df)

This fails with:

AttributeError: 'MyTransformerOne' object has no attribute 'spark'

I'm lost here. So, my questions are:

  1. Can I pass a SparkSession object to a custom transformer?
  2. How can I make this work? I really need to create those dataframes inside the transformer class (creating them outside the transformer makes no sense, since they won't be used for any other task).

Can you point me in the right direction?

1 answer:

Answer 0 (score: 0)

It turned out to be easier than I thought!

I found this answer: I can simply call SparkSession.builder.getOrCreate() inside the class. The version above failed because self.spark was never assigned anywhere: the @keyword_only decorator only collects the passed keyword arguments into self._input_kwargs, it does not set them as instance attributes. So, after importing the my_transformers module, every time a method needs the Spark session it just calls getOrCreate() to retrieve the active one.

So the complete code looks like this:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import (HasInputCol, HasInputCols, HasOutputCol, 
    HasOutputCols, Param)
from pyspark.sql import (SparkSession, types, functions as funcs)

# My Custom Transformer 1:
class MyTransformerOne(Transformer, HasInputCol, HasOutputCol):
    @keyword_only
    def __init__(self, inputCol='my_input', outputCol='my_output'):
        super(MyTransformerOne, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol='my_input', outputCol='my_output'):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        # HERE! I get the active SparkSession.
        spark = SparkSession.builder.getOrCreate()

        # I need a little dataframe here to perform some tasks:
        df = spark.createDataFrame(
            [
                {'col1': 1, 'col2': 'A'}, {'col1': 2, 'col2': 'B'}
            ],
            schema = types.StructType([
                types.StructField('col1', types.IntegerType(), True),
                types.StructField('col2', types.StringType(), True),
            ])
        )
        pass # Lots of other things happen here... the little dataframe above
             # is joined with the 'to be transformed' dataset and some columns
             # are calculated.
        return final_dataset

df = MyTransformerOne().transform(input_df)

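For completeness, passing the session explicitly (question 1) also works; it just takes one extra line, because @keyword_only never assigns the captured arguments as attributes. Here is a minimal sketch (the class name MyTransformerTwo and the fallback to getOrCreate() are my own additions, not part of the original code):

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql import SparkSession

class MyTransformerTwo(Transformer, HasInputCol, HasOutputCol):
    @keyword_only
    def __init__(self, spark=None, inputCol='my_input', outputCol='my_output'):
        super(MyTransformerTwo, self).__init__()
        # Keep the session on the instance ourselves. 'spark' is not a Param,
        # so it must not be forwarded to _set() via setParams below.
        self.spark = spark if spark is not None else SparkSession.builder.getOrCreate()
        kwargs = self._input_kwargs
        kwargs.pop('spark', None)
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol='my_input', outputCol='my_output'):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        # self.spark is now available to build the helper dataframe:
        df = self.spark.createDataFrame([(1, 'A'), (2, 'B')], ['col1', 'col2'])
        # ... join df with 'dataset', compute columns, etcetera
        return dataset  # placeholder for the real result

Usage would then be my_transformers.MyTransformerTwo(spark=spark).transform(input_df). That said, calling SparkSession.builder.getOrCreate() inside _transform as shown above does the same job with less ceremony: when a session is already running, getOrCreate() simply returns it. The session can also be reached from the incoming dataset itself (dataset.sql_ctx.sparkSession on the PySpark versions of that era; newer releases expose dataset.sparkSession directly).
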
I'll leave this post here and mark my question as a duplicate.