Error when using an Apache Spark decision tree classifier for multi-class classification

Asked: 2016-11-25 04:20:33

Tags: java apache-spark classification apache-spark-mllib decision-tree

I am trying to classify user activities from sensor data collected on a mobile device. The dataset contains a user ID, the sensor readings, and the activity. Activities are encoded as integers, and there are 12 activity classes. My code for this activity-recognition problem, which uses Apache Spark's MLlib DecisionTree for multi-class classification, is given below.
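For context, the training file is in LibSVM format, which MLUtils.loadLibSVMFile reads as one record per line: a label followed by index:value feature pairs. Two illustrative lines (the feature values here are invented; only the overall shape matches my file):

    11 1:0.27 2:-0.41 3:0.88
    4 1:0.05 2:0.33 3:-0.71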

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class DecisionTreeClass {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeClass").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);


        // Load and parse the data file.
        String datapath = "/home/thamali/Desktop/Project/csv/libsvm/trainlib.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        // Split the data into training and test sets (30% held out for testing)
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set parameters.
        //  Empty categoricalFeaturesInfo indicates all features are continuous.
        Integer numClasses = 12;
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        Integer maxDepth = 5;
        Integer maxBins = 32;

        // Train a DecisionTree model for classification.
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Evaluate model on test instances and compute test error
        JavaPairRDD<Double, Double> predictionAndLabel =
                testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        return new Tuple2<>(model.predict(p.features()), p.label());
                    }
                });
        Double testErr =
                1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Double, Double> pl) {
                        return !pl._1().equals(pl._2());
                    }
                }).count() / testData.count();

        System.out.println("Test Error: " + testErr);
        System.out.println("Learned classification tree model:\n" + model.toDebugString());

        // Save and load model
        model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
        DecisionTreeModel sameModel = DecisionTreeModel
                .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
    }

}

When I run the code above, I get the exception below. Can someone help me figure out what is wrong?

Caused by: java.lang.IllegalArgumentException: GiniAggregator given label 17.0 but requires label < numClasses (= 12).
    at org.apache.spark.mllib.tree.impurity.GiniAggregator.update(Gini.scala:92)
    at org.apache.spark.ml.tree.impl.DTStatsAggregator.update(DTStatsAggregator.scala:109)
    at org.apache.spark.ml.tree.impl.RandomForest$.orderedBinSeqOp(RandomForest.scala:326)
    at org.apache.spark.ml.tree.impl.RandomForest$.org$apache$spark$ml$tree$impl$RandomForest$$nodeBinSeqOp$1(RandomForest.scala:416)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$org$apache$spark$ml$tree$impl$RandomForest$$binSeqOp$1$1.apply(RandomForest.scala:441)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$org$apache$spark$ml$tree$impl$RandomForest$$binSeqOp$1$1.apply(RandomForest.scala:439)
    at scala.collection.immutable.Map$Map1.foreach(Map.scala:109)
    at org.apache.spark.ml.tree.impl.RandomForest$.org$apache$spark$ml$tree$impl$RandomForest$$binSeqOp$1(RandomForest.scala:439)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9$$anonfun$apply$9.apply(RandomForest.scala:532)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9$$anonfun$apply$9.apply(RandomForest.scala:532)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9.apply(RandomForest.scala:532)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9.apply(RandomForest.scala:521)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:785)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:785)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:79)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:47)
    at org.apache.spark.scheduler.Task.run(Task.scala:86)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)

1 Answer:

Answer 0 (score: 0)

The problem is the label values, not the tree parameters. MLlib's classifiers require every label to be a double in {0, 1, ..., numClasses - 1}. Your stack trace shows the data contains the label 17.0, so numClasses = 12 is too small; it must be at least 18. Change to:

    Integer numClasses = 18;
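Alternatively, if there really are only 12 distinct activities but their integer codes are not contiguous (they evidently reach at least 17), you can remap the raw labels onto 0..11 before splitting and training, and keep numClasses = 12. A minimal sketch against the `data` RDD from the question (the remapping helper is my own illustration, not part of the original code; on top of the question's imports it needs java.util.List, java.util.ArrayList, and java.util.Collections):

    // Collect the distinct raw labels and sort them so the mapping is deterministic.
    List<Double> rawLabels = new ArrayList<>(
            data.map(new Function<LabeledPoint, Double>() {
                @Override
                public Double call(LabeledPoint p) {
                    return p.label();
                }
            }).distinct().collect());
    Collections.sort(rawLabels);

    // Assign each raw label an index 0..k-1.
    final Map<Double, Double> labelToIndex = new HashMap<>();
    for (int i = 0; i < rawLabels.size(); i++) {
        labelToIndex.put(rawLabels.get(i), (double) i);
    }

    // Rewrite every LabeledPoint with its remapped label; the features are untouched.
    JavaRDD<LabeledPoint> remapped = data.map(new Function<LabeledPoint, LabeledPoint>() {
        @Override
        public LabeledPoint call(LabeledPoint p) {
            return new LabeledPoint(labelToIndex.get(p.label()), p.features());
        }
    });
    // Split and train on "remapped" instead of "data"; numClasses can stay 12.

Either way, it is worth printing the actual label range first (e.g. rawLabels.get(rawLabels.size() - 1)) to see how far the codes really go, and if you remap, remember to apply the same labelToIndex mapping to any new data you predict on.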