在迁移到 1.6.1 后,训练 Spark ML 线性回归模型失败

时间:2016-06-03 14:58:01

标签: apache-spark apache-spark-ml

我使用spark-ml来训练线性回归模型。 它与spark版本1.5.2完美配合,但现在1.6.1我得到以下错误:

java.lang.AssertionError: assertion failed: lapack.dppsv returned 228.

它似乎与某些低级线性代数库有关,但它在Spark版本更新之前运行良好。

在两个版本中,我都会在训练开始之前收到相同的警告,说它无法加载BLAS和LAPACK

[Executor task launch worker-6] com.github.fommil.netlib.BLAS - Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
[Executor task launch worker-6] com.github.fommil.netlib.BLAS - Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
[main]    com.github.fommil.netlib.LAPACK - Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
[main]    com.github.fommil.netlib.LAPACK - Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK

这是一个最小的代码:

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

public class Application {

    /**
     * Minimal reproduction / fix for the Spark 1.6.x linear-regression failure
     * {@code java.lang.AssertionError: assertion failed: lapack.dppsv returned 228.}
     *
     * <p>Since Spark 1.6 the {@link LinearRegression} solver defaults to "auto",
     * which selects the WLS (normal-equations) solver when the feature count is
     * small (&lt;= 4096) and no elastic-net parameter is set. WLS relies on
     * LAPACK's {@code dppsv}, which fails here (native BLAS/LAPACK could not be
     * loaded, and the one-hot encoded design matrix is rank deficient on this
     * tiny data set). Forcing the "l-bfgs" solver restores the 1.5.2 behavior.
     *
     * @param args unused command-line arguments
     */
    public static void main(String args[]) {

        // create a local context using all available cores
        JavaSparkContext javaSparkContext = new JavaSparkContext("local[*]", "CalculCote");
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        // describe fields (explicit schema, so no inference is needed)
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("brand", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("commercial_name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("mileage", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("price", DataTypes.DoubleType, true));

        // load dataframe from file; the spark-csv option key is "inferSchema"
        // (lower-case i) — the original "InferSchema" spelling was silently
        // ignored; it is harmless anyway because an explicit schema is supplied
        DataFrame df = sqlContext.read().format("com.databricks.spark.csv") //
                .option("header", "true") //
                .option("inferSchema", "false") //
                .option("delimiter", ";") //
                .schema(DataTypes.createStructType(fields)) //
                .load("input.csv").persist();

        // show first rows
        df.show();

        // indexers and encoders to turn categorical columns into vectors
        StringIndexer brandIndexer = new StringIndexer() //
                .setInputCol("brand") //
                .setOutputCol("brandIndex");

        OneHotEncoder brandEncoder = new OneHotEncoder() //
                .setInputCol("brandIndex") //
                .setOutputCol("brandVec");

        StringIndexer commNameIndexer = new StringIndexer() //
                .setInputCol("commercial_name") //
                .setOutputCol("commNameIndex");

        OneHotEncoder commNameEncoder = new OneHotEncoder() //
                .setInputCol("commNameIndex") //
                .setOutputCol("commNameVec");

        // assemble all predictors into a single "features" vector column
        VectorAssembler predictors = new VectorAssembler() //
                .setInputCols(new String[] { "brandVec", "commNameVec", "mileage" }) //
                .setOutputCol("features");

        // train model — force L-BFGS: since Spark 1.6 the default "auto"
        // solver may choose WLS (normal equations), which fails with
        // "lapack.dppsv returned 228" on this data (see class Javadoc)
        LinearRegression lr = new LinearRegression() //
                .setLabelCol("price") //
                .setSolver("l-bfgs");

        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { //
                brandIndexer, brandEncoder, commNameIndexer, commNameEncoder, predictors, lr });

        PipelineModel pm = pipeline.fit(df);

        DataFrame result = pm.transform(df);

        result.show();

        // release resources held by the local Spark context
        javaSparkContext.stop();
    }
}

和input.csv数据

brand;commercial_name;mileage;price
APRILIA;ATLANTIC 125;18237;1400
BMW;R1200 GS;10900;12400
HONDA;CB 1000;58225;4250
HONDA;CB 1000;1780;7610
HONDA;CROSSRUNNER 800;2067;11490
KAWASAKI;ER-6F 600;51600;2010
KAWASAKI;VERSYS 1000;5900;13900
KAWASAKI;VERSYS 650;3350;6200
KTM;SUPER DUKE 990;36420;4760

pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>test</groupId>
    <artifactId>sparkmigration</artifactId>
    <packaging>jar</packaging>
    <name>sparkmigration</name>
    <version>0.0.1</version>
    <url>http://maven.apache.org</url>



    <properties>
        <java.version>1.8</java.version>
        <!-- Toggle between the two Spark versions to reproduce the regression:
             1.5.2 works, 1.6.1 fails with "lapack.dppsv returned 228" unless
             the LinearRegression solver is forced to "l-bfgs". -->
        <spark.version>1.6.1</spark.version>
<!--        <spark.version>1.5.2</spark.version> -->
        <spark.csv.version>1.3.0</spark.csv.version>
        <slf4j.version>1.7.2</slf4j.version>
        <logback.version>1.0.9</logback.version>
    </properties>


    <dependencies>
        <!-- Scala 2.11 builds of Spark; all three modules must share
             ${spark.version} to avoid binary incompatibilities. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <!-- Exclude Spark's log4j binding so logging can be routed
                 through SLF4J/Logback below. -->
            <exclusions>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <!-- CSV datasource used by sqlContext.read().format("com.databricks.spark.csv") -->
        <dependency>
            <groupId>com.databricks</groupId>
            <artifactId>spark-csv_2.11</artifactId>
            <version>${spark.csv.version}</version>
        </dependency>


        <!-- Logs -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>log4j-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>jcl-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-core</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>   

    </dependencies>


    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.2</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>

1 个答案:

答案 0 :(得分:0)

问题已解决(感谢apache spark邮件列表)

自 Spark 1.6 起,线性回归的求解器默认被设置为“auto”,在某些情况下(特征数 <= 4096、未设置弹性网络参数等),会使用 WLS(加权最小二乘)算法代替 L-BFGS。

我将求解器强制设为 l-bfgs,问题随之解决:

LinearRegression lr = new LinearRegression().setLabelCol("price").setSolver("l-bfgs");