Problem applying Logistic Regression on CSV data

Asked: 2019-03-29 09:24:17

Tags: scala apache-spark apache-spark-sql logistic-regression apache-spark-ml

I am new to machine learning and am trying to learn it using Scala and Spark in local mode. My requirement is to apply Logistic Regression to CSV data.

Sample CSV data:

id,normalized_total_spent_last_24_hours,normalized_merchant_fraud_risk,normalized_time_since_last_transaction,normalized_average_transaction,normalized_days_till_expiration,normalized_transaction_time,normalized_change_in_merchant_sales,Amount,Class
0,-1.034133845,-0.513680076,-0.508604693,-2.196178501,-0.108862958,-1.061008629,0.285154155,135.75,0
1,-1.265759551,0.07327929,1.311443586,-0.734940773,1.450278841,-0.801969386,0.860978154,1.98,0
2,2.240560126,-1.509744002,-0.689632426,-1.622658556,-1.434514451,-0.419166831,-1.36019318,24,0
3,-22.32205074,-22.20892648,-8.997418067,3.396521112,1.155982154,-0.7160386,3.832327638,212,0
4,-0.522512757,0.81919506,1.777105544,1.013635885,0.306739941,-0.06426399,0.32108437,19.99,0
5,-2.089682661,0.849492313,0.790108223,-0.590925467,0.434408367,-0.805684103,0.523183012,3.99,0
6,-2.647158204,1.763548392,0.490936849,1.541377437,-0.949784452,-0.336538438,-0.706230268,9.46,0
7,-0.4630152,0.32577193,-0.139411116,-0.90596587,0.959955945,-0.809819817,1.687780067,149.95,0
8,-1.386557134,-1.320988511,1.579036707,-3.062784171,-0.437193393,-0.095830087,0.373993105,154,0
9,0.97974056,-0.420226839,1.036644229,0.580934286,-1.175975734,-0.445575941,-0.505391954,100.78,0
My code:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object Test {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("Simple Application")
      .master("local[*]")
      .getOrCreate()

    // Read the CSV with a header row and let Spark infer the column types
    val csvData = spark.read.format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load("file:///F:/test.csv")

    csvData.printSchema()

    val cols = Array("id", "normalized_total_spent_last_24_hours",
      "normalized_merchant_fraud_risk",
      "normalized_time_since_last_transaction",
      "normalized_average_transaction",
      "normalized_days_till_expiration",
      "normalized_transaction_time",
      "normalized_change_in_merchant_sales", "Amount")

    // Assemble the feature columns into a single vector column
    val assembler = new VectorAssembler()
      .setInputCols(cols)
      .setOutputCol("features")
    val pipeline = new Pipeline().setStages(Array(assembler))
    val df = pipeline.fit(csvData).transform(csvData)
    df.show(1)

    val splits = df.randomSplit(Array(0.8, 0.2), seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)

    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setFeaturesCol("features")
      .setLabelCol("Class")
    val lrModel = lr.fit(training)
    val predictions = lrModel.transform(training)
    predictions.show()
  }
}

In the dataset above, I want to use the Class column as the label and the remaining columns as my feature columns.
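As an aside, instead of listing the feature columns by hand, they can be derived from the DataFrame schema. A minimal sketch (reusing csvData from the code above, and assuming the id column should also be excluded from the features, which is usually desirable for a model):

import org.apache.spark.ml.feature.VectorAssembler

// Take every column except the label (and, by assumption, id) as features
val featureCols = csvData.columns.filterNot(c => c == "Class" || c == "id")

val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")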

With the above code I get the following error. Here is the console stack trace:

Caused by: java.io.NotSerializableException: scala.runtime.LazyRef
Serialization stack:
    - object not serializable (class: scala.runtime.LazyRef, value: LazyRef thunk)
    - element of array (index: 2)
    - array (class [Ljava.lang.Object;, size 3)
    - field (class: java.lang.invoke.SerializedLambda, name: capturedArgs, type: class [Ljava.lang.Object;)
    - object (class java.lang.invoke.SerializedLambda, SerializedLambda[capturingClass=class org.apache.spark.sql.catalyst.expressions.ScalaUDF, functionalInterfaceMethod=scala/Function1.apply:(Ljava/lang/Object;)Ljava/lang/Object;, implementation=invokeStatic org/apache/spark/sql/catalyst/expressions/ScalaUDF.$anonfun$f$2:(Lscala/Function1;Lorg/apache/spark/sql/catalyst/expressions/Expression;Lscala/runtime/LazyRef;Lorg/apache/spark/sql/catalyst/InternalRow;)Ljava/lang/Object;, instantiatedMethodType=(Lorg/apache/spark/sql/catalyst/InternalRow;)Ljava/lang/Object;, numCaptured=3])
    - writeReplace data (class: java.lang.invoke.SerializedLambda)
    - object (class org.apache.spark.sql.catalyst.expressions.ScalaUDF$$Lambda$1916/120999784, org.apache.spark.sql.catalyst.expressions.ScalaUDF$$Lambda$1916/120999784@2905b568)
    - field (class: org.apache.spark.sql.catalyst.expressions.ScalaUDF, name: f, type: interface scala.Function1)
    - object (class org.apache.spark.sql.catalyst.expressions.ScalaUDF, UDF(named_struct(id_double_vecAssembler_a001d143dede, cast(id#10 as double), normalized_total_spent_last_24_hours, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction, normalized_time_since_last_transaction#13, normalized_average_transaction, normalized_average_transaction#14, normalized_days_till_expiration, normalized_days_till_expiration#15, normalized_transaction_time, normalized_transaction_time#16, normalized_change_in_merchant_sales, normalized_change_in_merchant_sales#17, Amount, Amount#18, Class_double_vecAssembler_a001d143dede, cast(Class#19 as double))))
    - field (class: org.apache.spark.sql.catalyst.expressions.Alias, name: child, type: class org.apache.spark.sql.catalyst.expressions.Expression)
    - object (class org.apache.spark.sql.catalyst.expressions.Alias, UDF(named_struct(id_double_vecAssembler_a001d143dede, cast(id#10 as double), normalized_total_spent_last_24_hours, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction, normalized_time_since_last_transaction#13, normalized_average_transaction, normalized_average_transaction#14, normalized_days_till_expiration, normalized_days_till_expiration#15, normalized_transaction_time, normalized_transaction_time#16, normalized_change_in_merchant_sales, normalized_change_in_merchant_sales#17, Amount, Amount#18, Class_double_vecAssembler_a001d143dede, cast(Class#19 as double))) AS features#42)
    - element of array (index: 10)
    - array (class [Ljava.lang.Object;, size 11)
    - field (class: scala.collection.mutable.ArrayBuffer, name: array, type: class [Ljava.lang.Object;)
    - object (class scala.collection.mutable.ArrayBuffer, ArrayBuffer(id#10, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction#13, normalized_average_transaction#14, normalized_days_till_expiration#15, normalized_transaction_time#16, normalized_change_in_merchant_sales#17, Amount#18, Class#19, UDF(named_struct(id_double_vecAssembler_a001d143dede, cast(id#10 as double), normalized_total_spent_last_24_hours, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction, normalized_time_since_last_transaction#13, normalized_average_transaction, normalized_average_transaction#14, normalized_days_till_expiration, normalized_days_till_expiration#15, normalized_transaction_time, normalized_transaction_time#16, normalized_change_in_merchant_sales, normalized_change_in_merchant_sales#17, Amount, Amount#18, Class_double_vecAssembler_a001d143dede, cast(Class#19 as double))) AS features#42))
    - field (class: org.apache.spark.sql.execution.ProjectExec, name: projectList, type: interface scala.collection.Seq)
    - object (class org.apache.spark.sql.execution.ProjectExec, Project [id#10, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction#13, normalized_average_transaction#14, normalized_days_till_expiration#15, normalized_transaction_time#16, normalized_change_in_merchant_sales#17, Amount#18, Class#19, UDF(named_struct(id_double_vecAssembler_a001d143dede, cast(id#10 as double), normalized_total_spent_last_24_hours, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction, normalized_time_since_last_transaction#13, normalized_average_transaction, normalized_average_transaction#14, normalized_days_till_expiration, normalized_days_till_expiration#15, normalized_transaction_time, normalized_transaction_time#16, normalized_change_in_merchant_sales, normalized_change_in_merchant_sales#17, Amount, Amount#18, Class_double_vecAssembler_a001d143dede, cast(Class#19 as double))) AS features#42]
+- FileScan csv [id#10,normalized_total_spent_last_24_hours#11,normalized_merchant_fraud_risk#12,normalized_time_since_last_transaction#13,normalized_average_transaction#14,normalized_days_till_expiration#15,normalized_transaction_time#16,normalized_change_in_merchant_sales#17,Amount#18,Class#19] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/F:/test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int,normalized_total_spent_last_24_hours:double,normalized_merchant_fraud_risk:double,n...
)
    - field (class: org.apache.spark.sql.execution.SortExec, name: child, type: class org.apache.spark.sql.execution.SparkPlan)
    - object (class org.apache.spark.sql.execution.SortExec, Sort [id#10 ASC NULLS FIRST, normalized_total_spent_last_24_hours#11 ASC NULLS FIRST, normalized_merchant_fraud_risk#12 ASC NULLS FIRST, normalized_time_since_last_transaction#13 ASC NULLS FIRST, normalized_average_transaction#14 ASC NULLS FIRST, normalized_days_till_expiration#15 ASC NULLS FIRST, normalized_transaction_time#16 ASC NULLS FIRST, normalized_change_in_merchant_sales#17 ASC NULLS FIRST, Amount#18 ASC NULLS FIRST, Class#19 ASC NULLS FIRST, features#42 ASC NULLS FIRST], false, 0
+- Project [id#10, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction#13, normalized_average_transaction#14, normalized_days_till_expiration#15, normalized_transaction_time#16, normalized_change_in_merchant_sales#17, Amount#18, Class#19, UDF(named_struct(id_double_vecAssembler_a001d143dede, cast(id#10 as double), normalized_total_spent_last_24_hours, normalized_total_spent_last_24_hours#11, normalized_merchant_fraud_risk, normalized_merchant_fraud_risk#12, normalized_time_since_last_transaction, normalized_time_since_last_transaction#13, normalized_average_transaction, normalized_average_transaction#14, normalized_days_till_expiration, normalized_days_till_expiration#15, normalized_transaction_time, normalized_transaction_time#16, normalized_change_in_merchant_sales, normalized_change_in_merchant_sales#17, Amount, Amount#18, Class_double_vecAssembler_a001d143dede, cast(Class#19 as double))) AS features#42]
   +- FileScan csv [id#10,normalized_total_spent_last_24_hours#11,normalized_merchant_fraud_risk#12,normalized_time_since_last_transaction#13,normalized_average_transaction#14,normalized_days_till_expiration#15,normalized_transaction_time#16,normalized_change_in_merchant_sales#17,Amount#18,Class#19] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/F:/test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int,normalized_total_spent_last_24_hours:double,normalized_merchant_fraud_risk:double,n...
)
    - element of array (index: 0)
    - array (class [Ljava.lang.Object;, size 9)
    - element of array (index: 1)
    - array (class [Ljava.lang.Object;, size 3)
    - field (class: java.lang.invoke.SerializedLambda, name: capturedArgs, type: class [Ljava.lang.Object;)
    - object (class java.lang.invoke.SerializedLambda, SerializedLambda[capturingClass=class org.apache.spark.sql.execution.WholeStageCodegenExec, functionalInterfaceMethod=scala/Function2.apply:(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;, implementation=invokeStatic org/apache/spark/sql/execution/WholeStageCodegenExec.$anonfun$doExecute$4$adapted:(Lorg/apache/spark/sql/catalyst/expressions/codegen/CodeAndComment;[Ljava/lang/Object;Lorg/apache/spark/sql/execution/metric/SQLMetric;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, instantiatedMethodType=(Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, numCaptured=3])
    - writeReplace data (class: java.lang.invoke.SerializedLambda)
    - object (class org.apache.spark.sql.execution.WholeStageCodegenExec$$Lambda$1400/863366099, org.apache.spark.sql.execution.WholeStageCodegenExec$$Lambda$1400/863366099@191f4d65)
    at org.apache.spark.serializer.SerializationDebugger$.improveException(SerializationDebugger.scala:41)
    at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:46)
    at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:100)
    at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:400)
    ... 53 more

1 Answer:

Answer 0 (score: 0)

I observed that this problem is mainly caused by the Scala version. My Spark version is 2.4, and I was previously using Scala 2.12.3. I solved the problem by replacing the Scala library version in my project's build path with 2.11. I also observed that this problem mostly occurs when using Spark with Scala in local mode via Maven and Eclipse. I hope this answer helps anyone facing the same issue. Cheers and happy coding.
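For context, scala.runtime.LazyRef is a class introduced by the Scala 2.12 lambda encoding, so this exception typically signals that the Scala version the project compiles against does not match the Scala version the Spark artifacts were built for. A minimal build.sbt sketch of the alignment described above (version numbers are illustrative; with Maven, the equivalent fix is depending on the _2.11 artifacts such as spark-sql_2.11):

// Pin the project to the Scala 2.11 line that the Spark 2.4 artifacts were built against
scalaVersion := "2.11.12"

libraryDependencies ++= Seq(
  // %% appends the Scala binary suffix, resolving to spark-sql_2.11 / spark-mllib_2.11
  "org.apache.spark" %% "spark-sql"   % "2.4.0",
  "org.apache.spark" %% "spark-mllib" % "2.4.0"
)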