I am trying to use Spark 1.6 to compute quantiles for every column of a table, separately for each firm.
There are roughly 5,000 entries in firm_list and 300 entries in attr_lst. The table holds about 200,000 records.
I am running 10 executors with 16 GB of memory each.
Currently each quantile calculation takes about 1 second, and the full transformation for a single firm takes about 2 minutes, so at that rate all 5,000 firms would need roughly 10,000 minutes.
Please let me know how I can optimize the performance. (I have put one idea I am unsure about after the code below.)
from __future__ import print_function
from src.utils import sql_service, apa_constant as constant, file_io_service
from pyspark.sql.functions import monotonicallyIncreasingId
from pyspark.sql.functions import lit,col,broadcast
from pyspark.ml.feature import Bucketizer
from pyspark.ml import Pipeline
from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType,DoubleType
from pyspark.sql.types import *
from pyspark.ml.feature import *
from concurrent.futures import *
from functools import reduce
from pyspark.sql import DataFrame
import pyspark
import numpy as np
def generate_quantile_reports(spark_context, hive_context, log, attribute_type, **kwargs):
    # Describe the table and number its columns so the attribute columns can be picked by position.
    sql = """describe {}.apa_all_attrs_consortium""".format(kwargs['sem_db'])
    op = hive_context.sql(sql)
    res = op.withColumn("ordinal_position", monotonicallyIncreasingId())
    res.registerTempTable('attribs')
    # ~300 attribute columns
    attr_lst = hive_context.sql(
        """select col_name from attribs
           where ordinal_position > 24
             AND col_name not like '%vehicle%'
             AND col_name not like '%cluster_num%'
             AND col_name not like '%value_seg%'
           order by ordinal_position""").collect()
    # firm list (~5000 firms in the full run; limit 5 here)
    sql = """select distinct firm_id, firm_name
             from {}.apa_all_attrs_consortium
             where ud_rep = 1
               and lower(channel) not in ('ria', 'clearing')
             order by firm_id limit 5
          """.format(kwargs['sem_db'])
    dat = hive_context.sql(sql)
    firm_list = dat.collect()
    # Pull every attribute cast to double; ~200000 rows in total.
    sql = """select entity_id, cast(firm_id as double), %s
             from %s.apa_all_attrs_consortium
             where ud_rep = 1
               and lower(channel) not in ('ria', 'clearing')
             cluster by entity_id""" % (
        ", ".join("cast(" + str(attr.col_name) + " as double)" for attr in attr_lst),
        kwargs['sem_db'])
    df = hive_context.sql(sql)
    qrtl_list = []
    df.cache()
    df.count()  # materialise the cache before the per-firm loop
    counter = 0
    for (fm, fnm) in firm_list:
        df2 = df[df.firm_id == fm]
        df2 = df2.replace(0, np.nan)
        df_main = df2.drop('firm_id')
        counter += 1
        colNames = []
        quartileList = []
        bucketizerList = []
        for var in attr_lst:
            colNames.append(var.col_name)
        # Quartile boundaries for all columns of this firm via a Scala helper (approxQuantile).
        jdf = df2._jdf
        bindt = spark_context._jvm.com.dstsystems.apa.util.DFQuantileFunction.approxQuantile(
            jdf, colNames, [0.0, 0.25, 0.5, 0.75, 1.0], 0.0)
        for i in range(len(bindt)):
            quartile = sorted(list(set(list(bindt[i]))))
            quartile = [-float("inf")] + quartile
            quartile.insert(len(quartile), float("inf"))
            quartile.insert(len(quartile), float("NaN"))
            df_main = df_main.filter(df_main[colNames[i]].isNotNull())
            bucketizerList.append(Bucketizer()
                                  .setInputCol(colNames[i])
                                  .setOutputCol("{}_quantile".format(colNames[i]))
                                  .setSplits(quartile))
        path = "{}/tmpPqtDir/apa_team_broker_quartiles".format(kwargs['semzone_path'])
        # Bucketize every attribute of this firm into its quartile and keep the result.
        qrtl_list.append(Pipeline(stages=bucketizerList).fit(df_main).transform(df_main))
    finalDF = reduce(DataFrame.unionAll, qrtl_list)
    finalDF.repartition(200).write.mode("overwrite").option("header", "true").parquet(path)
    df.unpersist()
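
One direction I have been wondering about, but have not verified, is whether all of the per-firm quartiles could come out of a single grouped aggregation instead of 5,000 separate approxQuantile calls from the driver. Below is a minimal sketch of that idea only, assuming Hive's percentile_approx UDAF is reachable through HiveContext in 1.6, and with attr_1/attr_2 standing in for the ~300 real column names:

    # Sketch (not tested): compute quartile boundaries for every firm in one grouped pass,
    # letting the Hive UDAF percentile_approx do the work instead of a per-firm driver loop.
    # attr_1 / attr_2 are placeholders for the real attribute columns.
    candidate_cols = ["attr_1", "attr_2"]

    quartile_exprs = ", ".join(
        "percentile_approx(cast({c} as double), array(0.0, 0.25, 0.5, 0.75, 1.0)) as {c}_quartiles"
        .format(c=c) for c in candidate_cols)

    sql = """select firm_id, {exprs}
             from {db}.apa_all_attrs_consortium
             where ud_rep = 1
               and lower(channel) not in ('ria', 'clearing')
             group by firm_id""".format(exprs=quartile_exprs, db=kwargs['sem_db'])

    # One row per firm, one array<double> of quartile boundaries per attribute.
    firm_quartiles = hive_context.sql(sql)

If something like this works, the Bucketizer splits could presumably be built from the collected rows of firm_quartiles, but I don't know whether percentile_approx over ~300 columns in one query is itself fast enough, so any advice on this or another approach is welcome.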