Apache Spark Machine Learning - Can't get the Estimator example to work

Asked: 2015-09-16 15:56:27

Tags: java maven apache-spark

I'm having a hard time taking any of the example machine learning code from the Spark documentation and actually getting it to run as a Java program. Whether it's down to my limited knowledge of Java, Maven, or Spark (quite possibly all three), I can't find a useful explanation.

Take this example. To try to implement it, I used the following project structure:

.
├── pom.xml
└── src
    └── main
        └── java
            └── SimpleEstimator.java

The Java file looks like this:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;


public class SimpleEstimator {
  public static void main(String[] args) {
    DataFrame training = sqlContext.createDataFrame(Arrays.asList(
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
      new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
      new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
    ), LabeledPoint.class);

    LogisticRegression lr = new LogisticRegression();
    System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");

    lr.setMaxIter(10)
      .setRegParam(0.01);

    LogisticRegressionModel model1 = lr.fit(training);

    System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());

    ParamMap paramMap = new ParamMap()
      .put(lr.maxIter().w(20)) // Specify 1 Param.
      .put(lr.maxIter(), 30) // This overwrites the original maxIter.
      .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.

    ParamMap paramMap2 = new ParamMap()
      .put(lr.probabilityCol().w("myProbability")); // Change output column name
    ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

    LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
    System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());

    DataFrame test = sqlContext.createDataFrame(Arrays.asList(
      new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
      new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
    ), LabeledPoint.class);


    DataFrame results = model2.transform(test);
    for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
          + ", prediction=" + r.get(3));
    }
  }
}

The pom file is as follows:

<project>
  <groupId>edu.berkeley</groupId>
  <artifactId>simple-estimator</artifactId>
  <modelVersion>4.0.0</modelVersion>
  <name>Simple Estimator</name>
  <packaging>jar</packaging>
  <version>1.0</version>
  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>1.5.0</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>1.5.0</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>1.5.0</version>
    </dependency>
  </dependencies>
</project>

If I run mvn package from the root of that directory, I get the following error:

[INFO] Scanning for projects...
[INFO]
[INFO] ------------------------------------------------------------------------
[INFO] Building Simple Estimator 1.0
[INFO] ------------------------------------------------------------------------
[INFO]
[INFO] --- maven-resources-plugin:2.6:resources (default-resources) @ simple-estimator ---
[WARNING] Using platform encoding (UTF-8 actually) to copy filtered resources, i.e. build is platform dependent!
[INFO] skip non existing resourceDirectory /Users/philip/study/spark/estimator/src/main/resources
[INFO]
[INFO] --- maven-compiler-plugin:3.1:compile (default-compile) @ simple-estimator ---
[INFO] Changes detected - recompiling the module!
[WARNING] File encoding has not been set, using platform encoding UTF-8, i.e. build is platform dependent!
[INFO] Compiling 1 source file to /Users/philip/study/spark/estimator/target/classes
[INFO] -------------------------------------------------------------
[ERROR] COMPILATION ERROR :
[INFO] -------------------------------------------------------------
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[15,26] cannot find symbol
  symbol:   variable sqlContext
  location: class SimpleEstimator
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[44,22] cannot find symbol
  symbol:   variable sqlContext
  location: class SimpleEstimator
[INFO] 2 errors
[INFO] -------------------------------------------------------------
[INFO] ------------------------------------------------------------------------
[INFO] BUILD FAILURE
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 1.567 s
[INFO] Finished at: 2015-09-16T16:54:20+01:00
[INFO] Final Memory: 36M/422M
[INFO] ------------------------------------------------------------------------
[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.1:compile (default-compile) on project simple-estimator: Compilation failure: Compilation failure:
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[15,26] cannot find symbol
[ERROR] symbol:   variable sqlContext
[ERROR] location: class SimpleEstimator
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[44,22] cannot find symbol
[ERROR] symbol:   variable sqlContext
[ERROR] location: class SimpleEstimator
[ERROR] -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
[ERROR] Re-run Maven using the -X switch to enable full debug logging.
[ERROR]
[ERROR] For more information about the errors and possible solutions, please read the following articles:
[ERROR] [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/MojoFailureException

Update

Thanks to @holden, I made sure to add these lines:

// additional imports
import org.apache.spark.api.java.*;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SQLContext;

// added these as starting lines in class
SparkConf conf = new SparkConf().setAppName("Simple Estimator");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);

This made some progress, but now I get the following error:

[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.1:compile (default-compile) on project simple-estimator: Compilation failure
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[21,36] no suitable method found for createDataFrame(java.util.List<org.apache.spark.mllib.regression.LabeledPoint>,java.lang.Class<org.apache.spark.mllib.regression.LabeledPoint>)
[ERROR] method org.apache.spark.sql.SQLContext.<A>createDataFrame(org.apache.spark.rdd.RDD<A>,scala.reflect.api.TypeTags.TypeTag<A>) is not applicable
[ERROR] (cannot infer type-variable(s) A
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.rdd.RDD<A>))
[ERROR] method org.apache.spark.sql.SQLContext.<A>createDataFrame(scala.collection.Seq<A>,scala.reflect.api.TypeTags.TypeTag<A>) is not applicable
[ERROR] (cannot infer type-variable(s) A
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to scala.collection.Seq<A>))
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.rdd.RDD<org.apache.spark.sql.Row>,org.apache.spark.sql.types.StructType) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.rdd.RDD<org.apache.spark.sql.Row>)
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.api.java.JavaRDD<org.apache.spark.sql.Row>,org.apache.spark.sql.types.StructType) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.api.java.JavaRDD<org.apache.spark.sql.Row>)
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.rdd.RDD<?>,java.lang.Class<?>) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.rdd.RDD<?>)
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.api.java.JavaRDD<?>,java.lang.Class<?>) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.api.java.JavaRDD<?>)

The code the error refers to comes straight from the example:

DataFrame training = sqlContext.createDataFrame(Arrays.asList(
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
      new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
      new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
    ), LabeledPoint.class);

1 Answer:

Answer 0 (score: 4)

The examples generally don't create the sqlContext or sc (the SparkContext), since those are the same for every example. http://spark.apache.org/docs/latest/sql-programming-guide.html shows how to create the sqlContext, and http://spark.apache.org/docs/latest/quick-start.html shows how to create sc (the SparkContext).

You probably want something like the following.

Additional imports:

//Additional imports
import org.apache.spark.api.java.*;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SQLContext; // needed for the SQLContext declaration below

Add at the start of your main method:

// In your method:
SparkConf conf = new SparkConf().setAppName("Simple Application");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
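
One caveat (an assumption on my part about how you launch this): if you run the class directly rather than through spark-submit, the SparkConf also needs a master URL. A minimal sketch for local testing:

// Assumption: running locally without spark-submit, so set the master explicitly.
// "local[*]" runs Spark in-process using all available cores; with spark-submit
// you would pass --master on the command line instead.
SparkConf conf = new SparkConf()
  .setAppName("Simple Application")
  .setMaster("local[*]");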

Based on your update, the second problem you're running into is creating the DataFrame (again elided in the Java examples). The method you are trying to use hasn't been implemented yet (I actually have a pending pull request at https://github.com/apache/spark/pull/8779 implementing something similar, although that version takes Rows & schemas; I've added a JIRA to track adding it for local JavaBeans as well).

Thankfully, the extra step isn't much code. Instead of:

   DataFrame test = sqlContext.createDataFrame(Arrays.asList(
      new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
      new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
    ), LabeledPoint.class);

use:

   DataFrame test = sqlContext.createDataFrame(sc.parallelize(
      Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
        new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
        new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
    )), LabeledPoint.class);
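
For completeness, the training DataFrame at the top of your main method hits the same missing overload and needs the identical change; a sketch applying the same parallelize pattern to your training data:

   // Same fix for the training data: wrap the List in a JavaRDD first, then
   // let createDataFrame infer the schema from the LabeledPoint bean class.
   DataFrame training = sqlContext.createDataFrame(sc.parallelize(
      Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
        new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
        new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
        new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
    )), LabeledPoint.class);

Once you're done with the contexts, it's also good practice to call sc.stop() at the end of main.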