Is there a way to tune the PySpark MLlib code below, which computes quantiles, so that it performs better?

Time: 2019-04-24 12:44:18

Tags: apache-spark pyspark apache-spark-sql apache-spark-mllib

I am trying to compute, with Spark 1.6, the quantiles of every column in a table for a number of different firms.

There are about 5000 entries in firm_list and 300 entries in attr_lst. The table holds roughly 200,000 records.

I am running with 10 executors, each with 16 GB of memory.

Currently each quantile calculation takes about 1 second and the whole transformation for one firm takes about 2 minutes. At that rate the job would need roughly 10,000 minutes to cover all 5000 firms (5000 firms × ~2 minutes each).

Please let me know how the performance can be improved.

    from __future__ import print_function
    from src.utils import sql_service, apa_constant as constant, file_io_service
    from pyspark.sql.functions import monotonicallyIncreasingId
    from pyspark.sql.functions import lit, col, broadcast
    from pyspark.ml.feature import Bucketizer
    from pyspark.ml import Pipeline
    from pyspark.sql.types import StructType, StructField, StringType, DoubleType
    from pyspark.sql.types import *
    from pyspark.ml.feature import *
    from concurrent.futures import *
    from functools import reduce
    from pyspark.sql import DataFrame
    import pyspark
    import numpy as np


    def generate_quantile_reports(spark_context, hive_context, log,
                                  attribute_type, **kwargs):
        # Describe the table and tag each column with its ordinal position so the
        # attribute columns can be selected by position.
        sql = "describe {}.apa_all_attrs_consortium".format(kwargs['sem_db'])
        op = hive_context.sql(sql)
        res = op.withColumn("ordinal_position", monotonicallyIncreasingId())
        res.registerTempTable('attribs')

        # Attribute columns to profile (~300 entries).
        attr_lst = hive_context.sql(
            """select col_name from attribs
               where ordinal_position > 24
                 AND col_name not like '%vehicle%'
                 AND col_name not like '%cluster_num%'
                 AND col_name not like '%value_seg%'
               order by ordinal_position""").collect()

        # Firms to report on (~5000 entries; limited to 5 here while testing).
        sql = """select distinct firm_id, firm_name
                 from {}.apa_all_attrs_consortium
                 where ud_rep = 1
                   and lower(channel) not in ('ria', 'clearing')
                 order by firm_id limit 5""".format(kwargs['sem_db'])
        dat = hive_context.sql(sql)
        firm_list = dat.collect()

        # Pull every attribute column, cast to double, for all qualifying rows (~200,000).
        sql = """select entity_id, cast(firm_id as double), %s
                 from %s.apa_all_attrs_consortium
                 where ud_rep = 1
                   and lower(channel) not in ('ria', 'clearing')
                 cluster by entity_id""" % (
            ", ".join("cast(" + str(attr.col_name) + " as double)"
                      for attr in attr_lst),
            kwargs['sem_db'])

        df = hive_context.sql(sql)
        qrtl_list = []
        df.cache()
        df.count()  # materialise the cache before the per-firm loop
        counter = 0

        for (fm, fnm) in firm_list:
            # Restrict to one firm at a time.
            df2 = df[df.firm_id == fm]
            df2 = df2.replace(0, np.nan)
            df_main = df2.drop('firm_id')
            counter += 1
            colNames = [var.col_name for var in attr_lst]
            bucketizerList = []

            # Scala helper invoked through the JVM gateway; it returns the
            # approximate quartile boundaries for every column in colNames.
            jdf = df2._jdf
            bindt = (spark_context._jvm.com.dstsystems.apa.util
                     .DFQuantileFunction
                     .approxQuantile(jdf, colNames, [0.0, 0.25, 0.5, 0.75, 1.0], 0.0))

            for i in range(len(bindt)):
                # De-duplicate the boundaries and pad them with -inf/+inf/NaN so
                # the Bucketizer can place every value in the column.
                quartile = sorted(list(set(list(bindt[i]))))
                quartile = [-float("inf")] + quartile
                quartile.insert(len(quartile), float("inf"))
                quartile.insert(len(quartile), float("NaN"))
                df_main = df_main.filter(df_main[colNames[i]].isNotNull())
                bucketizerList.append(
                    Bucketizer()
                    .setInputCol(colNames[i])
                    .setOutputCol("{}_quantile".format(colNames[i]))
                    .setSplits(quartile))

            path = "{}/tmpPqtDir/apa_team_broker_quartiles".format(kwargs['semzone_path'])
            qrtl_list.append(
                Pipeline(stages=bucketizerList).fit(df_main).transform(df_main))

        finalDF = reduce(DataFrame.unionAll, qrtl_list)
        (finalDF.repartition(200)
         .write.mode("overwrite").option("header", "true").parquet(path))
        df.unpersist()
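
One direction I have been looking at (a minimal sketch only, not the current implementation) is to replace the per-firm loop with a single grouped aggregation using Hive's percentile_approx UDAF, which is available through HiveContext. The variable names and the `_qrtl` alias suffix below are placeholders, and the sketch assumes it runs inside generate_quantile_reports after attr_lst has been collected:

    # Sketch: compute quartile boundaries for all firms and all attribute columns
    # in one job instead of ~5000 filter + approxQuantile calls.
    # (The original code also replaces 0 with NaN before computing quantiles;
    # that step is omitted here.)
    col_names = [attr.col_name for attr in attr_lst]

    exprs = ", ".join(
        "percentile_approx(cast({c} as double), array(0.25, 0.5, 0.75)) as {c}_qrtl"
        .format(c=c)
        for c in col_names)

    quantile_sql = """select firm_id, {exprs}
                      from {db}.apa_all_attrs_consortium
                      where ud_rep = 1
                        and lower(channel) not in ('ria', 'clearing')
                      group by firm_id""".format(exprs=exprs, db=kwargs['sem_db'])

    # Each row holds one firm's array of quartile boundaries per column; the
    # Bucketizer stages could then be built from this result.
    firm_quartiles = hive_context.sql(quantile_sql)

The expected benefit is one shuffle over the whole table rather than thousands of per-firm jobs; the trade-off is the approximation accuracy of percentile_approx, which would need to be checked against the current helper's results.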

0 Answers:

No answers