How do I use linear regression from Apache Spark's MLlib?

Asked: 2014-05-29 20:07:55

Tags: java apache-spark apache-spark-mllib

I am new to Apache Spark. In the MLlib documentation I found a Scala example, but I don't know Scala at all. Does anyone know of an example in Java? Thanks! The sample code is:

import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint

// Load and parse the data
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, parts(1).split(' ').map(x => x.toDouble).toArray)
}

// Building the model
val numIterations = 20
val model = LinearRegressionWithSGD.train(parsedData, numIterations)

// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _) / valuesAndPreds.count
println("training Mean Squared Error = " + MSE)

This comes from the MLlib documentation. Thanks!

2 Answers:

Answer 0 (score: 3)

As the documentation says:

  All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate JavaRDD class. You can convert a Java RDD to a Scala one by calling .rdd() on your JavaRDD object.

It is not completely painless, since you still have to reproduce the Scala code in Java, but it works (at least in this case).
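To make that conversion concrete, here is a minimal, self-contained sketch (the class name and data are made up for illustration) of unwrapping a JavaRDD into the Scala RDD type that MLlib's Scala-facing methods expect:

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;

public class RddBridge {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("rddBridge").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Build an RDD through the Java API...
        JavaRDD<Double> javaRdd = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0));

        // ...and call .rdd() to unwrap the underlying Scala RDD,
        // which is the type MLlib's train() methods accept.
        RDD<Double> scalaRdd = javaRdd.rdd();
        System.out.println(scalaRdd.count()); // 3

        sc.stop();
    }
}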

That said, here is a Java implementation:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;

public void linReg() {
    String master = "local";
    SparkConf conf = new SparkConf().setAppName("csvParser").setMaster(master);
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaRDD<String> data = sc.textFile("mllib/data/ridge-data/lpsa.data");

    // Parse each line of the form "label,feature1 feature2 ..." into a LabeledPoint.
    JavaRDD<LabeledPoint> parsedData = data
            .map(new Function<String, LabeledPoint>() {
                // An anonymous class is more verbose than the Scala closure;
                // with Java 8 this could also be written as a lambda.
                @Override
                public LabeledPoint call(String line) throws Exception {
                    String[] parts = line.split(",");
                    String[] pointsStr = parts[1].split(" ");
                    double[] points = new double[pointsStr.length];
                    for (int i = 0; i < pointsStr.length; i++) {
                        points[i] = Double.parseDouble(pointsStr[i]);
                    }
                    return new LabeledPoint(Double.parseDouble(parts[0]),
                            Vectors.dense(points));
                }
            });

    // Building the model
    int numIterations = 20;
    LinearRegressionModel model = LinearRegressionWithSGD.train(
            parsedData.rdd(), numIterations); // notice the .rdd()

    // Evaluate model on training examples and compute training error.
    // The important point here is the explicit Tuple2 creation.
    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = parsedData
            .map(point -> new Tuple2<Double, Double>(point.label(), model
                    .predict(point.features())));

    // mapToDouble() yields a JavaDoubleRDD, whose mean() makes computing
    // the MSE much easier than a manual reduce-and-divide.
    double MSE = valuesAndPreds.mapToDouble(
            tuple -> Math.pow(tuple._1 - tuple._2, 2)).mean();
    System.out.println("training Mean Squared Error = " + MSE);
}

It is far from perfect, but I hope it gives you a better idea of how to use the Scala examples from the MLlib documentation.
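As a hypothetical follow-up (not part of the original answer), the trained model can then score new observations. The feature values below are invented; they only need to match the dimensionality of the training data (8 features in lpsa.data):

// Hypothetical usage sketch; also import org.apache.spark.mllib.linalg.Vector.
// The feature values are made up for illustration.
Vector newPoint = Vectors.dense(1.2, 3.4, 65.0, 0.3, 0.0, -0.5, 7.0, 30.0);
double prediction = model.predict(newPoint);
System.out.println("prediction for new point = " + prediction);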

Answer 1 (score: 1)

(This answer pastes Spark's bundled JavaHdfsLR example, which, as its own Javadoc notes, implements logistic regression by hand on the RDD API rather than using MLlib's linear regression.)

package org.apache.spark.examples;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import java.util.regex.Pattern;

/**
 * Logistic regression based classification.
 *
 * This is an example implementation for learning how to use Spark. For more conventional use,
 * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
 * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
 */
public final class JavaHdfsLR {

  private static final int D = 10;   // Number of dimensions
  private static final Random rand = new Random(42);

  static void showWarning() {
    String warning = "WARN: This is a naive implementation of Logistic Regression " +
            "and is given as an example!\n" +
            "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " +
            "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " +
            "for more conventional use.";
    System.err.println(warning);
  }

  static class DataPoint implements Serializable {
    DataPoint(double[] x, double y) {
      this.x = x;
      this.y = y;
    }

    double[] x;
    double y;
  }

  static class ParsePoint implements Function<String, DataPoint> {
    private static final Pattern SPACE = Pattern.compile(" ");

    @Override
    public DataPoint call(String line) {
      String[] tok = SPACE.split(line);
      double y = Double.parseDouble(tok[0]);
      double[] x = new double[D];
      for (int i = 0; i < D; i++) {
        x[i] = Double.parseDouble(tok[i + 1]);
      }
      return new DataPoint(x, y);
    }
  }

  static class VectorSum implements Function2<double[], double[], double[]> {
    @Override
    public double[] call(double[] a, double[] b) {
      double[] result = new double[D];
      for (int j = 0; j < D; j++) {
        result[j] = a[j] + b[j];
      }
      return result;
    }
  }

  static class ComputeGradient implements Function<DataPoint, double[]> {
    private final double[] weights;

    ComputeGradient(double[] weights) {
      this.weights = weights;
    }

    @Override
    public double[] call(DataPoint p) {
      double[] gradient = new double[D];
      // The dot product does not depend on i, so compute it once outside the loop.
      double dot = dot(weights, p.x);
      for (int i = 0; i < D; i++) {
        gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
      }
      return gradient;
    }
  }

  public static double dot(double[] a, double[] b) {
    double x = 0;
    for (int i = 0; i < D; i++) {
      x += a[i] * b[i];
    }
    return x;
  }

  public static void printWeights(double[] a) {
    System.out.println(Arrays.toString(a));
  }

  public static void main(String[] args) {

    if (args.length < 2) {
      System.err.println("Usage: JavaHdfsLR <file> <iters>");
      System.exit(1);
    }

    showWarning();

    SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    JavaRDD<String> lines = sc.textFile(args[0]);
    JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
    int ITERATIONS = Integer.parseInt(args[1]);

    // Initialize w to a random value
    double[] w = new double[D];
    for (int i = 0; i < D; i++) {
      w[i] = 2 * rand.nextDouble() - 1;
    }

    System.out.print("Initial w: ");
    printWeights(w);

    for (int i = 1; i <= ITERATIONS; i++) {
      System.out.println("On iteration " + i);

      double[] gradient = points.map(
        new ComputeGradient(w)
      ).reduce(new VectorSum());

      for (int j = 0; j < D; j++) {
        w[j] -= gradient[j];
      }

    }

    System.out.print("Final w: ");
    printWeights(w);
    sc.stop();
  }
}
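If you want to try this class, Spark distributions ship a run-example launcher; assuming a space-separated input file whose lines look like "label x1 ... x10" (the code hard-codes D = 10), an invocation along the lines of ./bin/run-example JavaHdfsLR /path/to/lr_data.txt 10 should work, though the exact form may vary with your Spark version.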