How to get a confidence score from a Spark MLlib logistic regression model in Java

Date: 2016-09-05 10:27:11

Tags: java apache-spark-mllib logistic-regression

UPDATE: I tried to generate a confidence score in the following way, but it throws an exception. The snippet I use is below:

// Dot product of the model weights with the feature vector, then the sigmoid:
double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
double confScore = 1.0 / (1.0 + Math.exp(-point));

The exception I get:

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
    at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)
Could you please help? It looks like the weights vector has more elements (198) than the data vector (I generate 18 features), yet they must be the same length in the dot() function.
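To see where the mismatch comes from, it helps to print the model's dimensions: in MLlib's multinomial logistic regression the weights vector packs (numClasses - 1) per-class blocks, so it is normally much longer than a single feature vector and cannot be passed to BLAS.dot directly. A minimal diagnostic sketch, reusing the variable names from the snippet above:

    int numClasses = logisticregressionmodel.numClasses();
    int weightsSize = logisticregressionmodel.weights().size();
    System.out.println("numClasses = " + numClasses
            + ", weights.size() = " + weightsSize
            + ", per-class block = " + (weightsSize / (numClasses - 1))
            + ", features = " + datavector.size());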

I am trying to implement a Java program that trains on an existing data set and predicts on a new data set using the logistic regression algorithm from Spark MLlib (1.5.0), with the multiclass implementation. My training and prediction programs are below. The problem: when I call model.predict(vector) (note lrmodel.predict() in the prediction program), I get the predicted label. But what if I also need a confidence score? How do I get it? I have gone through the API and could not find any method that returns a confidence score. Can someone help me?

Training program (generates the .model file)

public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;

        try {
           ...
       SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Train").setMaster(
                            sparkMaster);
            jsc = new JavaSparkContext(sparkConf);
            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

           ...
            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    StringBuilder featureSet =
                                            new StringBuilder();
                                   ...
                                        featureSet.append(i - 2);
                                        featureSet.append(COLON);
                                        featureSet.append(strVal);
                                        featureSet.append(SPACE);
                                    }

                                    return featureSet.toString().trim();
                                }
                            });

            List<String> featureList = featureRDD.collect();
            String featureOutput = StringUtils.join(featureList, NEW_LINE);
            String filePath = basePath + "lr.arff";
            FileUtils.writeStringToFile(new File(filePath), featureOutput,
                    "UTF-8");

            JavaRDD<LabeledPoint> trainingData =
                    MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();

            final LogisticRegressionModel model =
                    new LogisticRegressionWithLBFGS().setNumClasses(18).run(
                            trainingData.rdd());
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(model);
            oos.flush();
            oos.close();
            FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"),
                    baos.toByteArray());
            baos.close();

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }
}
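As a side note, serializing the model through ObjectOutputStream as above works, but MLlib models since Spark 1.3 also implement Saveable, so the built-in save/load may be more robust across versions. A hedged sketch, with an illustrative directory path:

    // In the training program:
    model.save(jsc.sc(), basePath + "lrModelDir");

    // In the prediction program:
    LogisticRegressionModel lrmodel =
            LogisticRegressionModel.load(jsc.sc(), basePath + "lrModelDir");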

Prediction program (uses the lr.model generated by the training program)

    public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;
        try {
            ...
        SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
            jsc = new JavaSparkContext(sparkConf);

            ObjectInputStream objectInputStream =
                    new ObjectInputStream(new FileInputStream(basePath
                            + "lr.model"));
            LogisticRegressionModel lrmodel =
                    (LogisticRegressionModel) objectInputStream.readObject();
            objectInputStream.close();

            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

            ...
            final Broadcast<LogisticRegressionModel> broadcastModel =
                    jsc.broadcast(lrmodel);

            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    LogisticRegressionModel lrModel =
                                            broadcastModel.value();
                                    String row = ((String) arg0);
                                    String[] featureSetArray =
                                            row.split(CSV_SPLITTER);
                                   ...
                                    final Vector vector =
                                            Vectors.dense(doubleArr);
                                    double score = lrModel.predict(vector);
                                   ...
                                    return csvString;
                                }
                            });

            String outputContent =
                    featureRDD
                            .reduce(new org.apache.spark.api.java.function.Function2() {

                                private static final long serialVersionUID =
                                        1212970144641935082L;

                                public Object call(Object arg0, Object arg1)
                                        throws Exception {
                                    ...
                                }

                            });
            ...
            FileUtils.writeStringToFile(new File(basePath
                    + "predicted-sales-data.csv"), sb.toString());
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }
}

2 answers:

Answer 0 (score: 0)

Indeed, it does not seem to be possible out of the box. Looking at the source code, you could extend it to return those probabilities:

https://github.com/apache/spark/blob/branch-1.5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

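One related point: for a binary model, clearThreshold() already makes predict() return the raw probability instead of a 0/1 label, so no extension is needed in that case. It does not apply to a multiclass model such as the 18-class one in the question, where the margins have to be recomputed as in the linked predictPoint source. A minimal sketch for the binary case:

    // Binary-only: after clearThreshold(), predict() returns the raw
    // class-1 probability in [0, 1] rather than a thresholded label.
    // This does not apply when numClasses > 2.
    lrmodel.clearThreshold();
    double probability = lrmodel.predict(vector);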

I hope this helps you find a workaround.

Answer 1 (score: 0)

After many attempts, I finally managed to write a custom function that generates a confidence score. It is far from perfect, but it works for me for now!

private static double getConfidenceScore(
        final LogisticRegressionModel lrModel, final Vector vector) {
    /* Approach to get confidence scores starts */
    // Mirrors the margin computation in LogisticRegressionModel.predictPoint:
    // the weights vector packs (numClasses - 1) per-class blocks, each of
    // length dataWithBiasSize (the features, plus one intercept slot if the
    // model was trained with an intercept).
    double[] weights = lrModel.weights().toArray();
    double[] features = vector.toArray();
    int numClasses = lrModel.numClasses();
    int dataWithBiasSize = weights.length / (numClasses - 1);
    boolean withBias = (vector.size() + 1) == dataWithBiasSize;
    double maxMargin = 0.0;
    for (int j = 0; j < (numClasses - 1); j++) {
        double margin = 0.0;
        for (int k = 0; k < features.length; k++) {
            double value = features[k];
            if (value != 0.0) {
                margin += value * weights[(j * dataWithBiasSize) + k];
            }
        }
        if (withBias) {
            // The last entry of each per-class block is the intercept.
            margin += weights[(j * dataWithBiasSize) + vector.size()];
        }
        if (margin > maxMargin) {
            maxMargin = margin;
        }
    }
    // Squash the winning margin through the sigmoid and report it as a
    // percentage rounded to two decimal places.
    double conf = 1.0 / (1.0 + Math.exp(-maxMargin));
    DecimalFormat twoDForm = new DecimalFormat("#.##");
    double confidenceScore = Double.valueOf(twoDForm.format(conf * 100));
    /* Approach to get confidence scores ends */
    return confidenceScore;
}
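A hypothetical call site inside the map() of the prediction program above, reusing the names from that snippet:

    final Vector vector = Vectors.dense(doubleArr);
    double label = lrModel.predict(vector);                  // predicted class
    double confidence = getConfidenceScore(lrModel, vector); // e.g. 87.23 (percent)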