It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation

Time: 2019-09-26 14:07:11

Tags: pyspark apache-spark-sql rdd apache-spark-ml

I want to train a model for each key, but I have many keys. I am trying both an RDD and a DataFrame approach: I bundle the MLlib calls into a single function that trains a model for one key and returns the corresponding test metrics, and then call that function with rdd.map. I have looked at this post, but it did not help me.

def metrics_per_key(key):
    import pyspark.sql.functions as F
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import (StringIndexer, OneHotEncoder, VectorAssembler,
                                    VectorIndexer, IndexToString)
    from pyspark.ml.classification import RandomForestClassifier
    from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    # load the data, derive the label from 'rank', and bucket 'hour' into day parts
    df = spark.read.csv('path to csv', header=True, inferSchema=True)
    df = df.withColumn('label', df['rank'] - 1)
    df = df.withColumn('day_part', F.when(df.hour < 3, 'g1').when(df.hour < 6, 'g2')
                       .when(df.hour < 9, 'g3').when(df.hour < 12, 'g4')
                       .when(df.hour < 15, 'g5').when(df.hour < 18, 'g6')
                       .when(df.hour < 21, 'g7').otherwise('g8'))
    df_filtered = df.filter(F.col('key') == key).drop('key')

    # index the categorical columns, then one-hot encode the indices
    stringIndexer_day = StringIndexer(inputCol="day", outputCol="dayIndex")
    stringIndexer_day_hr = StringIndexer(inputCol="day_hour", outputCol="day_hourIndex")
    stringIndexer_day_part = StringIndexer(inputCol="day_part", outputCol="day_partIndex")
    model_day = stringIndexer_day.fit(df)
    indexed_day = model_day.transform(df)
    model_day_hour = stringIndexer_day_hr.fit(indexed_day)
    indexed_all = model_day_hour.transform(indexed_day)
    model_day_part = stringIndexer_day_part.fit(indexed_all)
    indexed_all_including_day_part = model_day_part.transform(indexed_all)
    encoder_day = OneHotEncoder(inputCol="dayIndex", outputCol="dayIndexVec")
    encoder_dayHour = OneHotEncoder(inputCol="day_hourIndex", outputCol="day_hourIndexVec")
    encoder_hour = OneHotEncoder(inputCol="hour", outputCol="hourIndexVec")
    encoder_day_part = OneHotEncoder(inputCol="day_partIndex", outputCol="day_partIndexVec")
    encoded_day = encoder_day.transform(indexed_all_including_day_part)
    encode_day_dayHour = encoder_dayHour.transform(encoded_day)
    encoded_all = encoder_hour.transform(encode_day_dayHour)
    encoded_all_with_day_part = encoder_day_part.transform(encoded_all)

    # assemble the feature vector and keep only the columns needed for training
    assembler = VectorAssembler(
        inputCols=["hourIndexVec", "dayIndexVec", "day_hourIndexVec", "day_partIndexVec", "bid"],
        outputCol="features")
    assembled = assembler.transform(encoded_all_with_day_part)
    assembled = assembled.select(["key", "label", "features"])
    assembled.persist()

    # random forest pipeline with label/feature indexing and prediction decoding
    labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(assembled)
    featureIndexer = VectorIndexer(inputCol="features", outputCol="indexedFeatures",
                                   maxCategories=4).fit(assembled)
    (trainingData, testData) = assembled.randomSplit([0.8, 0.2], seed=0)
    rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")
    labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",
                                   labels=labelIndexer.labels)
    pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter])

    # grid search with 5-fold cross-validation
    paramGrid_rf = (ParamGridBuilder()
                    .addGrid(rf.maxDepth, [10, 20, 25, 30])
                    .addGrid(rf.numTrees, [10, 20, 30, 40, 50])
                    .addGrid(rf.maxBins, [16, 32, 48, 64])
                    .build())
    crossval = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid_rf,
                              evaluator=MulticlassClassificationEvaluator(),
                              numFolds=5, parallelism=10)
    model = crossval.fit(trainingData)
    predictions = model.transform(testData)

    # evaluate on the held-out test split
    precision = MulticlassClassificationEvaluator(labelCol="indexedLabel",
        predictionCol="prediction", metricName="weightedPrecision").evaluate(predictions)
    recall = MulticlassClassificationEvaluator(labelCol="indexedLabel",
        predictionCol="prediction", metricName="weightedRecall").evaluate(predictions)
    accuracy = MulticlassClassificationEvaluator(labelCol="indexedLabel",
        predictionCol="prediction", metricName="accuracy").evaluate(predictions)
    f1 = MulticlassClassificationEvaluator(labelCol="indexedLabel",
        predictionCol="prediction", metricName="f1").evaluate(predictions)
    return {'f1_test': f1, 'precision_test': precision, 'accuracy_test': accuracy, 'recall_test': recall}
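
For reference, the function runs without error when it is called directly on the driver for a single key. A minimal sketch of that usage (the key value 'k1' is a placeholder, not from the original post):

single_result = metrics_per_key('k1')   # executes entirely on the driver; 'k1' is a hypothetical key value
print(single_result)                    # {'f1_test': ..., 'precision_test': ..., 'accuracy_test': ..., 'recall_test': ...}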

I want to use the above function with rdd.map:

df = spark.read.csv('path to csv', header=True, inferSchema=True) 
keys = df.select('key').rdd
results = keys.map(metrics_per_key).collect()
  

Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that runs on workers. For more information, see SPARK-5063.
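SPARK-5063 is raised because metrics_per_key captures the SparkSession (spark.read.csv and the MLlib fits), and rdd.map ships that closure to the executors, where the SparkContext is not available. A minimal sketch of the driver-only pattern that avoids the error, assuming the number of distinct keys is small enough to collect (this is an illustrative workaround, not part of the original post):

# collect the distinct keys to the driver and loop in plain Python,
# so every Spark call inside metrics_per_key executes on the driver
keys = [row['key'] for row in df.select('key').distinct().collect()]
results = {k: metrics_per_key(k) for k in keys}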

0 Answers:

There are no answers.