How to get the weights from the bestModel of a CrossValidator with MultilayerPerceptronClassifier - Spark MLlib - PySpark

Date: 2019-11-25 17:37:26

Tags: apache-spark pyspark

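I managed to get the best model and the best parameters from a model that uses CrossValidator with MultilayerPerceptronClassifier, but I cannot get the weights. I have tried several approaches without success. I am using Zeppelin and Spark 2.3.2.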

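from pyspark.ml import Pipeline
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

layers = [4, 4, 4]

mlpc = MultilayerPerceptronClassifier(layers=layers, seed=12345, blockSize=624)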
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

pipelineMLPC = Pipeline(stages=[mlpc])

paramGrid = ParamGridBuilder() \
    .addGrid(mlpc.maxIter, range(135,136)) \
    .addGrid(mlpc.tol, [1E-3,1E-4]) \
    .addGrid(mlpc.solver, ["l-bfgs"]) \
    .build()

cvPipeline = CrossValidator(estimator = pipelineMLPC,
                    estimatorParamMaps = paramGrid,
                    evaluator = MulticlassClassificationEvaluator(metricName="accuracy"),
                    numFolds = 2)

cvModelPipeline = cvPipeline.fit(train)


#Make predictions on test
prediction = cvModelPipeline.transform(test)

# Get the accuracy of the best model
predictionAndLabels = prediction.select("prediction", "label")
accuracy_test = evaluator.evaluate(predictionAndLabels)
print("Accuracy test = " + str(accuracy_test))

# Get the best model
bestModelMLPC = cvModelPipeline.bestModel.stages[-1]._java_obj.parent()

# Get the best parameters of the best model
bestMaxIter = bestModelMLPC.getMaxIter()
bestTol = bestModelMLPC.getTol()
bestSolver = bestModelMLPC.getSolver()
bestBlockSize = bestModelMLPC.getBlockSize()
bestSeed = bestModelMLPC.getSeed()


bestWeights = bestModelMLPC.weights
print(bestWeights)
#bestModelMLPC.weights 
#bestModelMLPC.weight
#bestModelMLPC.getClassifier().weights
#bestModelMLPC.getClassifier().weight

print("Best Params: maxIter = " + str(bestMaxIter) + "; tol = " + str(bestTol) + "; solver = " + str(bestSolver) + "; blockSize = " + str(bestBlockSize) + ";  seed = " + str(bestSeed) + "; initialWeights = " + str(bestWeights))

This gives me some Java object, but not the weight values I need. Does anyone know how I can get them?

Accuracy test = 0.392935019625
<py4j.java_gateway.JavaMember object at 0x7efe75d45190>
Best Params: maxIter = 135; tol = 0.001; solver = l-bfgs; blockSize = 624;  seed = 12345; initialWeights = <py4j.java_gateway.JavaMember object at 0x7efe75d45190>
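
One possible explanation for the `JavaMember` output: `_java_obj.parent()` returns the Java-side estimator through py4j, and plain attribute access on a py4j object such as `.weights` yields an uncalled method reference rather than a value. The sketch below reads the weights from the Python-side model wrapper instead. It assumes the last stage of `bestModel` is a `MultilayerPerceptronClassificationModel` that exposes a `weights` property, which should be the case in PySpark 2.x. The helper name `best_mlp_weights` is hypothetical and not part of Spark.

```python
def best_mlp_weights(cv_model):
    """Return the flat list of weights of the best MLP model.

    Assumption: cv_model is a fitted CrossValidatorModel whose estimator
    was a Pipeline ending in a MultilayerPerceptronClassifier, as in the
    question above.
    """
    # Use the Python wrapper of the fitted stage, not _java_obj.parent():
    # parent() is the estimator, which holds params but no trained weights.
    mlp_model = cv_model.bestModel.stages[-1]

    # weights is a property (no parentheses) that returns an ml Vector;
    # toArray() converts it to an array of numbers.
    return [float(w) for w in mlp_model.weights.toArray()]
```

With the variables from the question, this would be called as `bestWeights = best_mlp_weights(cvModelPipeline)` followed by `print(bestWeights)`. The parameter getters taken from `_java_obj.parent()` can stay as they are, since only the weights live on the fitted model.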

0 answers:

No answers