我是 Apache Spark 的新手。我在 MLlib 的文档中找到了一个 Scala 示例，但我不懂 Scala，有人知道对应的 Java 示例吗？谢谢！示例代码如下：
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
// Load and parse the data; each line looks like "label,feat1 feat2 feat3 ..."
// (assumes `sc` is an already-created SparkContext, as in the spark-shell)
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
val parts = line.split(',')
// first field is the label, second field is the space-separated feature vector
LabeledPoint(parts(0).toDouble, parts(1).split(' ').map(x => x.toDouble).toArray)
}
// Building the model
val numIterations = 20
val model = LinearRegressionWithSGD.train(parsedData, numIterations)
// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
// Mean squared error over the training set: average of squared residuals
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _)/valuesAndPreds.count
println("training Mean Squared Error = " + MSE)
以上代码来自 MLlib 的文档。谢谢！
答案 0（得分：3）
如文档中所示：
MLlib 的所有方法都使用对 Java 友好的类型，因此您可以像在 Scala 中一样导入并调用它们。唯一需要注意的是，这些方法接受 Scala 的 RDD 对象，而 Spark 的 Java API 使用单独的 JavaRDD 类。您可以通过在 JavaRDD 对象上调用 .rdd() 将其转换为 Scala RDD。
这并不轻松，因为你仍然需要用 Java 重写 Scala 代码，但它是可行的（至少在这个例子中如此）。
话虽如此,这是一个java实现:
/**
 * Java port of the MLlib linear-regression-with-SGD Scala example.
 * Loads the lpsa data set, trains a {@code LinearRegressionModel} and
 * prints the mean squared error over the training set.
 */
public void linReg() {
    String master = "local";
    SparkConf conf = new SparkConf().setAppName("csvParser").setMaster(master);
    JavaSparkContext sc = new JavaSparkContext(conf);
    try {
        JavaRDD<String> data = sc.textFile("mllib/data/ridge-data/lpsa.data");
        // Each input line looks like: "label,feat1 feat2 feat3 ..."
        JavaRDD<LabeledPoint> parseddata = data
                .map(new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) throws Exception {
                        String[] parts = line.split(",");
                        String[] pointsStr = parts[1].split(" ");
                        double[] points = new double[pointsStr.length];
                        for (int i = 0; i < pointsStr.length; i++) {
                            // parseDouble avoids the boxing Double.valueOf would do
                            points[i] = Double.parseDouble(pointsStr[i]);
                        }
                        return new LabeledPoint(Double.parseDouble(parts[0]),
                                Vectors.dense(points));
                    }
                });
        // Building the model; train() takes a Scala RDD, hence the .rdd() call
        int numIterations = 20;
        LinearRegressionModel model = LinearRegressionWithSGD.train(
                parseddata.rdd(), numIterations);
        // Evaluate model on training examples: pair each label with its prediction.
        // The explicit Tuple2 creation is the important point here.
        JavaRDD<Tuple2<Double, Double>> valuesAndPred = parseddata
                .map(point -> new Tuple2<Double, Double>(point.label(), model
                        .predict(point.features())));
        // mapToDouble + mean() computes the MSE in one pass, much easier than
        // the manual reduce/count of the Scala snippet
        double MSE = valuesAndPred.mapToDouble(
                tuple -> Math.pow(tuple._1 - tuple._2, 2)).mean();
        System.out.println("training Mean Squared Error = " + MSE);
    } finally {
        // Always release the Spark context, even if the job throws;
        // the original leaked it on failure.
        sc.stop();
    }
}
它远非完美，但我希望它能帮助您更好地理解如何把 MLlib 文档中的 Scala 示例移植到 Java。
答案 1（得分：1）
package org.apache.spark.examples;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import java.util.regex.Pattern;
/**
 * Logistic regression based classification.
 *
 * This is an example implementation for learning how to use Spark. For more conventional use,
 * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
 * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
 */
public final class JavaHdfsLR {

  private static final int D = 10; // Number of dimensions
  private static final Random rand = new Random(42); // fixed seed for reproducible runs

  /** Prints a warning steering users toward the production MLlib implementations. */
  static void showWarning() {
    String warning = "WARN: This is a naive implementation of Logistic Regression " +
        "and is given as an example!\n" +
        "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " +
        "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " +
        "for more conventional use.";
    System.err.println(warning);
  }

  /** A labeled example: feature vector {@code x} (length D) and label {@code y}. */
  static class DataPoint implements Serializable {
    DataPoint(double[] x, double y) {
      this.x = x;
      this.y = y;
    }

    double[] x;
    double y;
  }

  /** Parses a space-separated line "y x1 x2 ... xD" into a DataPoint. */
  static class ParsePoint implements Function<String, DataPoint> {
    private static final Pattern SPACE = Pattern.compile(" ");

    @Override
    public DataPoint call(String line) {
      String[] tok = SPACE.split(line);
      double y = Double.parseDouble(tok[0]);
      double[] x = new double[D];
      for (int i = 0; i < D; i++) {
        x[i] = Double.parseDouble(tok[i + 1]);
      }
      return new DataPoint(x, y);
    }
  }

  /** Element-wise vector addition, used to reduce per-point gradients. */
  static class VectorSum implements Function2<double[], double[], double[]> {
    @Override
    public double[] call(double[] a, double[] b) {
      double[] result = new double[D];
      for (int j = 0; j < D; j++) {
        result[j] = a[j] + b[j];
      }
      return result;
    }
  }

  /** Computes the logistic-loss gradient contribution of a single data point. */
  static class ComputeGradient implements Function<DataPoint, double[]> {
    private final double[] weights;

    ComputeGradient(double[] weights) {
      this.weights = weights;
    }

    @Override
    public double[] call(DataPoint p) {
      // dot(weights, p.x) does not depend on i — hoisted out of the loop;
      // the original recomputed it for every dimension (O(D^2) per point).
      double dot = dot(weights, p.x);
      double scale = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y;
      double[] gradient = new double[D];
      for (int i = 0; i < D; i++) {
        gradient[i] = scale * p.x[i];
      }
      return gradient;
    }
  }

  /** Dot product of two length-D vectors. */
  public static double dot(double[] a, double[] b) {
    double x = 0;
    for (int i = 0; i < D; i++) {
      x += a[i] * b[i];
    }
    return x;
  }

  /** Prints a weight vector in {@code [w0, w1, ...]} form. */
  public static void printWeights(double[] a) {
    System.out.println(Arrays.toString(a));
  }

  public static void main(String[] args) {
    if (args.length < 2) {
      System.err.println("Usage: JavaHdfsLR <file> <iters>");
      System.exit(1);
    }

    showWarning();

    SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    JavaRDD<String> lines = sc.textFile(args[0]);
    // cache(): the points are re-read on every gradient-descent iteration
    JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
    int ITERATIONS = Integer.parseInt(args[1]);

    // Initialize w to a random value in [-1, 1)
    double[] w = new double[D];
    for (int i = 0; i < D; i++) {
      w[i] = 2 * rand.nextDouble() - 1;
    }

    System.out.print("Initial w: ");
    printWeights(w);

    // Batch gradient descent: sum per-point gradients, then step (unit learning rate)
    for (int i = 1; i <= ITERATIONS; i++) {
      System.out.println("On iteration " + i);
      double[] gradient = points.map(
          new ComputeGradient(w)
      ).reduce(new VectorSum());
      for (int j = 0; j < D; j++) {
        w[j] -= gradient[j];
      }
    }

    System.out.print("Final w: ");
    printWeights(w);
    sc.stop();
  }
}