How to configure the number of classes/labels for StreamingLogisticRegressionWithSGD

Asked: 2016-05-18 12:40:25

Tags: apache-spark-mllib

I am new to Spark MLlib. I am trying to implement a StreamingLogisticRegressionWithSGD model, and the Spark documentation provides very little information. When I enter 2,22-22-22 on the socket stream, I get

ERROR DataValidators: Classification labels should be 0 or 1. Found 1 invalid labels

As I understand it, this model expects input labels to be 0 or 1, but I would really like to know whether it can be configured for more labels. I cannot figure out how to set the number of classes for StreamingLogisticRegressionWithSGD.

Thanks!

Code

package test;

import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.StreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class SLRPOC {

    private static StreamingLogisticRegressionWithSGD slrModel;

    private static int numFeatures = 3;

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setMaster("local[3]").setAppName("SLRPOC");
        SparkContext sc = new SparkContext(sparkConf);
        StreamingContext ssc = new StreamingContext(sc, Durations.seconds(10));
        JavaStreamingContext jssc = new JavaStreamingContext(ssc);

        slrModel = new StreamingLogisticRegressionWithSGD().setStepSize(0.5).setNumIterations(10).setInitialWeights(Vectors.zeros(numFeatures));

        slrModel.trainOn(getDStreamTraining(jssc));
        slrModel.predictOn(getDStreamPrediction(jssc)).foreachRDD(new Function<JavaRDD<Double>, Void>() {

            private static final long serialVersionUID = 5287086933555760190L;

            @Override
            public Void call(JavaRDD<Double> v1) throws Exception {
                List<Double> list = v1.collect();
                for (Double d : list) {
                    System.out.println(d);
                }
                return null;
            }
        });

        jssc.start();
        jssc.awaitTermination();
    }

    public static JavaDStream<LabeledPoint> getDStreamTraining(JavaStreamingContext context) {
        JavaReceiverInputDStream<String> lines = context.socketTextStream("localhost", 9998);

        return lines.map(new Function<String, LabeledPoint>() {

            private static final long serialVersionUID = 1268686043314386060L;

            @Override
            public LabeledPoint call(String data) throws Exception {
                System.out.println("Inside LabeledPoint call : ----- ");
                // Expected training input format: "label,f1-f2-f3", e.g. "1,2-3-4"
                String arr[] = data.split(",");
                double vc[] = new double[numFeatures];
                String vcS[] = arr[1].split("-");
                int i = 0;
                for (String vcSi : vcS) {
                    vc[i++] = Double.parseDouble(vcSi);
                }
                return new LabeledPoint(Double.parseDouble(arr[0]), Vectors.dense(vc));
            }
        });
    }

    public static JavaDStream<Vector> getDStreamPrediction(JavaStreamingContext context) {
        JavaReceiverInputDStream<String> lines = context.socketTextStream("localhost", 9999);

        return lines.map(new Function<String, Vector>() {

            private static final long serialVersionUID = 1268686043314386060L;

            @Override
            public Vector call(String data) throws Exception {
                System.out.println("Inside Vector call : ----- ");
                // Expected prediction input format: "f1-f2-f3", e.g. "2-3-4"
                String vcS[] = data.split("-");
                double vc[] = new double[numFeatures];
                int i = 0;
                for (String vcSi : vcS) {
                    vc[i++] = Double.parseDouble(vcSi);
                }
                return Vectors.dense(vc);
            }
        });
    }
}
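
For reference, the training parser above expects lines of the form label,f1-f2-f3, so the input 2,22-22-22 is read as label 2 with features [22, 22, 22], and the label 2 is exactly what DataValidators rejects. The parsing and the label check can be sketched and tested in isolation; ParsedPoint is a hypothetical helper class for illustration, not part of the MLlib API:

```java
// Standalone sketch of the input parsing used in getDStreamTraining.
// ParsedPoint is a hypothetical helper for illustration only.
public class ParsedPoint {
    public final double label;
    public final double[] features;

    public ParsedPoint(double label, double[] features) {
        this.label = label;
        this.features = features;
    }

    // Parses "label,f1-f2-f3" into a label and a feature array.
    public static ParsedPoint parse(String line) {
        String[] parts = line.split(",");
        double label = Double.parseDouble(parts[0]);
        String[] raw = parts[1].split("-");
        double[] features = new double[raw.length];
        for (int i = 0; i < raw.length; i++) {
            features[i] = Double.parseDouble(raw[i]);
        }
        return new ParsedPoint(label, features);
    }

    // The binary SGD model only accepts labels 0.0 or 1.0.
    public static boolean isValidBinaryLabel(double label) {
        return label == 0.0 || label == 1.0;
    }
}
```

Feeding 2,22-22-22 through this parser shows the label 2.0 failing the binary check, which is the same condition that produces the "Classification labels should be 0 or 1" error.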

Exception

  

Inside LabeledPoint call : -----
16/05/18 17:51:10 INFO Executor: Finished task 0.0 in stage 4.0 (TID 4). 953 bytes result sent to driver
16/05/18 17:51:10 INFO TaskSetManager: Finished task 0.0 in stage 4.0 (TID 4) in 8 ms on localhost (1/1)
16/05/18 17:51:10 INFO TaskSchedulerImpl: Removed TaskSet 4.0, whose tasks have all completed, from pool
16/05/18 17:51:10 INFO DAGScheduler: ResultStage 4 (trainOn at SLRPOC.java:33) finished in 0.009 s
16/05/18 17:51:10 INFO DAGScheduler: Job 6 finished: trainOn at SLRPOC.java:33, took 0.019578 s
16/05/18 17:51:10 ERROR DataValidators: Classification labels should be 0 or 1. Found 1 invalid labels
16/05/18 17:51:10 INFO JobScheduler: Starting job streaming job 1463574070000 ms.1 from job set of time 1463574070000 ms
16/05/18 17:51:10 ERROR JobScheduler: Error running job streaming job 1463574070000 ms.0
org.apache.spark.SparkException: Input validation failed.
    at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn$1.apply(StreamingLinearAlgorithm.scala:94)
    at org.apache.spark.mllib.regression.StreamingLinearAlgorithm$$anonfun$trainOn$1.apply(StreamingLinearAlgorithm.scala:92)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$1$$anonfun$apply$mcV$sp$1.apply$mcV$sp(ForEachDStream.scala:42)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$1$$anonfun$apply$mcV$sp$1.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$1$$anonfun$apply$mcV$sp$1.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.DStream.createRDDWithLocalProperties(DStream.scala:399)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$1.apply$mcV$sp(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$1.apply(ForEachDStream.scala:40)
    at org.apache.spark.streaming.dstream.ForEachDStream$$anonfun$1.apply(ForEachDStream.scala:40)
    at scala.util.Try$.apply(Try.scala:161)
    at org.apache.spark.streaming.scheduler.Job.run(Job.scala:34)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run$1.apply$mcV$sp(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run$1.apply(JobScheduler.scala:207)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler$$anonfun$run$1.apply(JobScheduler.scala:207)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:57)
    at org.apache.spark.streaming.scheduler.JobScheduler$JobHandler.run(JobScheduler.scala:206)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)
Exception in thread "main" org.apache.spark.SparkException: Input validation failed.
    at org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.run(GeneralizedLinearAlgorithm.scala:251)
    ... (same stack trace as above)
16/05/18 17:51:10 INFO StreamingContext: Invoking stop(stopGracefully=false) from shutdown hook
16/05/18 17:51:10 INFO SparkContext: Starting job: foreachRDD at SLRPOC.java:34
16/05/18 17:51:10 INFO DAGScheduler: Job 7 finished: foreachRDD at SLRPOC.java:34, took 0.000020 s
16/05/18 17:51:10 INFO JobScheduler: Finished job streaming job 1463574070000 ms.1 from job set of time 1463574070000 ms
16/05/18 17:51:10 INFO ReceiverTracker: Sent stop signal to all 2 receivers
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Received stop signal
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopping receiver with message: Stopped by driver:
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Called receiver onStop
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Deregistering receiver 1
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Received stop signal
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopping receiver with message: Stopped by driver:
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Called receiver onStop
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Deregistering receiver 0
16/05/18 17:51:10 ERROR ReceiverTracker: Deregistered receiver for stream 1: Stopped by driver
16/05/18 17:51:10 INFO ReceiverSupervisorImpl: Stopped receiver 1
16/05/18 17:51:10 ERROR ReceiverTracker: Deregistered receiver for stream 0: Stopped by driver

1 Answer:

Answer 0 (score: 2)

Not sure if you have already resolved this, but the algorithm you are using is binary and only allows 2 classifications, 0 or 1. If you want more, you need to use a multiclass classification algorithm:

import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
new LogisticRegressionWithLBFGS().setNumClasses(10)
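
To flesh that out slightly: LogisticRegressionWithLBFGS exposes setNumClasses, which StreamingLogisticRegressionWithSGD does not, but it trains on a static RDD rather than a DStream. A minimal Java sketch (not runnable without a Spark cluster; trainingData is an assumed JavaRDD<LabeledPoint> whose labels are 0..9):

```java
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

// trainingData: assumed JavaRDD<LabeledPoint> with labels 0, 1, ..., 9
LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
        .setNumClasses(10)        // allow 10 classes instead of the binary default
        .run(trainingData.rdd()); // batch training on a static RDD, not a DStream

// Predict the class for a single feature vector
double prediction = model.predict(Vectors.dense(22.0, 22.0, 22.0));
```

If you need streaming specifically, one option under this constraint is to keep the binary streaming model and train one-vs-rest models per class yourself, but MLlib does not provide that combination out of the box.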