I use spark-ml to train a linear regression model. It worked perfectly with Spark version 1.5.2, but with 1.6.1 I now get the following error:
java.lang.AssertionError: assertion failed: lapack.dppsv returned 228.
It seems related to some low-level linear algebra library, but the same code ran fine before the Spark version upgrade.
In both versions I get the same warnings before training starts, saying the native BLAS and LAPACK implementations could not be loaded:
[Executor task launch worker-6] com.github.fommil.netlib.BLAS - Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
[Executor task launch worker-6] com.github.fommil.netlib.BLAS - Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
[main] com.github.fommil.netlib.LAPACK - Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
[main] com.github.fommil.netlib.LAPACK - Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK
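These warnings mean netlib-java falls back to its pure-JVM F2J implementation. One way to get the native wrappers on the classpath (a sketch, not verified against this setup: the 1.1.2 version is an assumption based on what Spark 1.x builds against, and a native BLAS/LAPACK such as OpenBLAS must additionally be installed on the machine) is an extra Maven dependency:

```xml
<!-- netlib-java native wrappers; the OS still needs a native
     BLAS/LAPACK installed for these to actually load. -->
<dependency>
    <groupId>com.github.fommil.netlib</groupId>
    <artifactId>all</artifactId>
    <version>1.1.2</version>
    <type>pom</type>
</dependency>
```

Whether loading a native LAPACK changes the dppsv result is untested here; the warnings themselves appear in both Spark versions, so they are not the direct cause of the assertion.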
Here is a minimal reproduction:
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

public class Application {
    public static void main(String[] args) {
        // create context
        JavaSparkContext javaSparkContext = new JavaSparkContext("local[*]", "CalculCote");
        SQLContext sqlContext = new SQLContext(javaSparkContext);
        // describe fields
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("brand", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("commercial_name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("mileage", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("price", DataTypes.DoubleType, true));
        // load dataframe from file
        DataFrame df = sqlContext.read().format("com.databricks.spark.csv") //
                .option("header", "true") //
                .option("inferSchema", "false") //
                .option("delimiter", ";") //
                .schema(DataTypes.createStructType(fields)) //
                .load("input.csv").persist();
        // show first rows
        df.show();
        // indexers and encoders for non numerical values
        StringIndexer brandIndexer = new StringIndexer() //
                .setInputCol("brand") //
                .setOutputCol("brandIndex");
        OneHotEncoder brandEncoder = new OneHotEncoder() //
                .setInputCol("brandIndex") //
                .setOutputCol("brandVec");
        StringIndexer commNameIndexer = new StringIndexer() //
                .setInputCol("commercial_name") //
                .setOutputCol("commNameIndex");
        OneHotEncoder commNameEncoder = new OneHotEncoder() //
                .setInputCol("commNameIndex") //
                .setOutputCol("commNameVec");
        // model predictors
        VectorAssembler predictors = new VectorAssembler() //
                .setInputCols(new String[] { "brandVec", "commNameVec", "mileage" }) //
                .setOutputCol("features");
        // train model
        LinearRegression lr = new LinearRegression().setLabelCol("price");
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { //
                brandIndexer, brandEncoder, commNameIndexer, commNameEncoder, predictors, lr });
        PipelineModel pm = pipeline.fit(df);
        DataFrame result = pm.transform(df);
        result.show();
    }
}
and the input.csv data:
brand;commercial_name;mileage;price
APRILIA;ATLANTIC 125;18237;1400
BMW;R1200 GS;10900;12400
HONDA;CB 1000;58225;4250
HONDA;CB 1000;1780;7610
HONDA;CROSSRUNNER 800;2067;11490
KAWASAKI;ER-6F 600;51600;2010
KAWASAKI;VERSYS 1000;5900;13900
KAWASAKI;VERSYS 650;3350;6200
KTM;SUPER DUKE 990;36420;4760
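To make the preprocessing stages concrete, here is a plain-Java sketch (no Spark, just an illustration of the semantics) of what StringIndexer and OneHotEncoder do to a categorical column like `brand`: the indexer assigns indices by descending label frequency, and the encoder (with Spark's default of dropping the last category) maps index i to a vector of size numLabels - 1 with a single 1.0. The tie-breaking between equally frequent labels is an assumption of this sketch and may differ from Spark's.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PipelineSketch {
    // Mimics StringIndexer: indices ordered by descending label frequency
    // (ties broken by first appearance in this sketch).
    static Map<String, Integer> indexByFrequency(List<String> values) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String v : values) counts.merge(v, 1, Integer::sum);
        List<String> labels = new ArrayList<>(counts.keySet());
        labels.sort((a, b) -> counts.get(b) - counts.get(a)); // stable sort
        Map<String, Integer> index = new HashMap<>();
        for (int i = 0; i < labels.size(); i++) index.put(labels.get(i), i);
        return index;
    }

    // Mimics OneHotEncoder with dropLast=true: index i becomes a vector of
    // size (numLabels - 1); the last category is the all-zero vector.
    static double[] oneHot(int index, int numLabels) {
        double[] vec = new double[numLabels - 1];
        if (index < numLabels - 1) vec[index] = 1.0;
        return vec;
    }

    public static void main(String[] args) {
        List<String> brands = Arrays.asList("APRILIA", "BMW", "HONDA", "HONDA",
                "HONDA", "KAWASAKI", "KAWASAKI", "KAWASAKI", "KTM");
        Map<String, Integer> idx = indexByFrequency(brands);
        System.out.println(idx.get("HONDA"));
        System.out.println(Arrays.toString(oneHot(idx.get("HONDA"), idx.size())));
    }
}
```

The VectorAssembler stage then simply concatenates these vectors with the numeric `mileage` column into one `features` vector.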
pom.xml
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>test</groupId>
    <artifactId>sparkmigration</artifactId>
    <packaging>jar</packaging>
    <name>sparkmigration</name>
    <version>0.0.1</version>
    <url>http://maven.apache.org</url>
    <properties>
        <java.version>1.8</java.version>
        <spark.version>1.6.1</spark.version>
        <!-- <spark.version>1.5.2</spark.version> -->
        <spark.csv.version>1.3.0</spark.csv.version>
        <slf4j.version>1.7.2</slf4j.version>
        <logback.version>1.0.9</logback.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <exclusions>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>com.databricks</groupId>
            <artifactId>spark-csv_2.11</artifactId>
            <version>${spark.csv.version}</version>
        </dependency>
        <!-- Logs -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>log4j-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>jcl-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-core</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.2</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
Answer 0 (score: 0)
Problem solved (thanks to the apache spark mailing list).
Since Spark 1.6, the linear regression solver defaults to "auto": under certain conditions (number of features <= 4096, no elastic-net parameter set, ...), the WLS (normal equation) algorithm is used instead of L-BFGS.
I forced the solver to l-bfgs and it works:
LinearRegression lr = new LinearRegression().setLabelCol("price").setSolver("l-bfgs");
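The selection rule described above can be sketched as follows (an approximation reconstructed from the description, not Spark's exact source; the threshold and the elastic-net condition are as stated in the answer):

```java
public class SolverChoice {
    // Approximates Spark 1.6's "auto" solver selection: the normal-equation
    // (WLS) path is taken for small feature counts with no L1 component;
    // otherwise, or when the user forces a solver, L-BFGS (or the forced
    // solver) is used.
    static String chooseSolver(String solver, int numFeatures, double elasticNetParam) {
        if (!"auto".equals(solver)) return solver; // e.g. setSolver("l-bfgs")
        if (numFeatures <= 4096 && elasticNetParam == 0.0) return "normal"; // WLS path
        return "l-bfgs";
    }

    public static void main(String[] args) {
        System.out.println(chooseSolver("auto", 10, 0.0));   // small problem -> WLS
        System.out.println(chooseSolver("l-bfgs", 10, 0.0)); // forced, as in the fix above
    }
}
```

So with the small one-hot-encoded feature set here, "auto" lands on the WLS path whose lapack.dppsv call fails; forcing "l-bfgs" sidesteps it.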