How to find feature importance and effectiveness when running the Random Forest algorithm with Spark ML

Asked: 2016-01-19 07:10:29

Tags: r random-forest

I am trying to run the Random Forest algorithm in Spark, but I need the feature importances. Can anyone suggest how to obtain them?

Here is the code I have tried:

 import java.util.Map;

 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineModel;
 import org.apache.spark.ml.PipelineStage;
 import org.apache.spark.ml.Transformer;
 import org.apache.spark.ml.classification.RandomForestClassificationModel;
 import org.apache.spark.ml.classification.RandomForestClassifier;
 import org.apache.spark.ml.feature.IndexToString;
 import org.apache.spark.ml.feature.StringIndexer;
 import org.apache.spark.ml.feature.StringIndexerModel;
 import org.apache.spark.ml.feature.VectorIndexer;
 import org.apache.spark.ml.feature.VectorIndexerModel;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;

 public final class RandomForestMlib {

 @SuppressWarnings("serial")
 public static void main(String[] args) throws Exception {

   // Run locally, using all available cores.
   SparkConf sparkConf = new SparkConf().setAppName("GRP").setMaster("local[*]");
   SparkContext ctx = new SparkContext(sparkConf);

   // Load the data set in LIBSVM format as an RDD of labeled points.
   String path = "dataSetnew.txt";
   JavaRDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(ctx, path).toJavaRDD();

   // Convert the RDD of LabeledPoint into a DataFrame with "label"/"features" columns.
   SQLContext sqlContext = new org.apache.spark.sql.SQLContext(ctx);
   DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);

  // Index labels, adding metadata to the label column.
  // Fit on whole dataset to include all labels in index.
  StringIndexerModel labelIndexer = new StringIndexer()
   .setInputCol("label")
   .setOutputCol("indexedLabel")
   .fit(data);
  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()
   .setInputCol("features")
   .setOutputCol("indexedFeatures")
    .fit(data);

   // Inspect which features VectorIndexer decided to treat as categorical.
   Map<Integer, Map<Double, Integer>> categoryMaps = featureIndexer.javaCategoryMaps();
   System.out.print("Chose " + categoryMaps.size() + " categorical features:");

   for (Map.Entry<Integer, Map<Double, Integer>> feature : categoryMaps.entrySet()) {
    System.out.print(" " + feature.getKey());
    // Each categorical feature maps raw values to category indices.
    for (Map.Entry<Double, Integer> category : feature.getValue().entrySet()) {
     System.out.println(" value " + category.getKey() + " -> index " + category.getValue());
    }
   }
   System.out.println();

   // Split the data into training and test sets (30% held out for testing).
   DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
   DataFrame trainingData = splits[0];
   DataFrame testData = splits[1];


  // Train a RandomForest model.
  RandomForestClassifier rf = new RandomForestClassifier()
   .setLabelCol("indexedLabel")
   .setFeaturesCol("indexedFeatures");
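   // The classifier's defaults are used here; hyper-parameters such as the
   // number of trees or tree depth can be tuned via setNumTrees(...) /
   // setMaxDepth(...) before fitting.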

  // Convert indexed labels back to original labels.
  IndexToString labelConverter = new IndexToString()
   .setInputCol("prediction")
   .setOutputCol("predictedLabel")
   .setLabels(labelIndexer.labels());

  // Chain indexers and forest in a Pipeline
  Pipeline pipeline = new Pipeline()
   .setStages(new PipelineStage[] {
    labelIndexer,
    featureIndexer,
    rf,
    labelConverter
   });

  // Train model.  This also runs the indexers.
  PipelineModel model = pipeline.fit(trainingData);

  // Make predictions.
  DataFrame predictions = model.transform(testData);

  // Select example rows to display.
  predictions.select("predictedLabel", "label", "features").show(5);

   // Extract the trained forest: stage index 2 of the fitted pipeline
   // (labelIndexer = 0, featureIndexer = 1, rf = 2, labelConverter = 3).
   RandomForestClassificationModel rfModel =
    (RandomForestClassificationModel)(model.stages()[2]);



  System.out.println("Stage 1" + model.stages()[1]);

  System.out.println("Stage 2" + model.stages()[0]);


  Transformer[] trans = model.stages();

  System.out.println("length of the array :" + trans.length);
  /* for(int i = 0 ; i <trans.length ; i++ ){

         System.out.println("length :"+i+1);

    }
*/

   // Feature importances: one weight per feature, normalized to sum to 1.
   Vector featureImp = rfModel.featureImportances();

   // Score a single example against the trained forest.
   Vector denseVecnew = Vectors.dense(112, 110, 0, 0, 0, 0, 0, 0, 0, 0, 0);
   double pred = rfModel.predict(denseVecnew);
   System.out.println("Prediction : " + pred);

   System.out.println(featureImp);
   System.out.println("feature size : " + featureImp.size());
   System.out.println("featureIndexer : " + featureIndexer.numFeatures());

   double[] importanceArray = featureImp.toArray();
   double sum = 0;
   for (int i = 0; i < importanceArray.length; i++) {
    sum += importanceArray[i];
    System.out.println("importance for index " + i + " : " + importanceArray[i]);
   }
   System.out.println("sum = " + sum);

   ctx.stop();
  }
 }
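
For reference, once the importance vector has been converted to a plain `double[]` (as with `featureImp.toArray()` above), ranking the scores makes the most influential features easy to spot. The following is a minimal, self-contained sketch of such a ranking; the class name `ImportanceRanker` and the sample values are purely illustrative, and it assumes nothing beyond standard `java.util` classes.

 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;

 public final class ImportanceRanker {

  // Print feature indices sorted by importance score, highest first.
  public static void printRanked(final double[] importances) {
   List<Integer> indices = new ArrayList<>();
   for (int i = 0; i < importances.length; i++) {
    indices.add(i);
   }
   // Order the indices by the score they point at, descending.
   indices.sort(Comparator.comparingDouble((Integer i) -> importances[i]).reversed());
   for (int idx : indices) {
    System.out.println("feature " + idx + " -> " + importances[idx]);
   }
  }

  public static void main(String[] args) {
   // Hypothetical importance values, for illustration only.
   printRanked(new double[] {0.05, 0.42, 0.13, 0.40});
  }
 }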

0 Answers:

There are no answers.