I've built a Spark (2.2.0) ML pipeline whose output is a CrossValidatorModel, and I've written the pipeline out with the save method. I'd like to serve this pre-trained model using the Play (2.6.0) framework with Scala (2.11.11), but I'm having some trouble working out how to use Spark and Play together, and/or what's the best way to load the model.
As for my Play setup, the relevant part of my file structure is quite simple:
app/
  controllers/
    HomeController.scala
    ModelScorer.scala
  models/
    Passenger.scala
    Prediction.scala
conf/
  routes
where Passenger and Prediction are case classes representing the model's input and output, respectively. HomeController holds the logic that extracts a POST request in JSON form, parses the contents into a Seq[Passenger], and feeds it to ModelScorer.predict(data), as shown below.
// HomeController.scala
package controllers

import javax.inject._
import models.{Passenger, Prediction}
import play.api.mvc._
import play.api.libs.json._
import play.api.libs.functional.syntax._

@Singleton
class HomeController @Inject()(cc: ControllerComponents) extends AbstractController(cc) {

  implicit val passengerReads: Reads[Passenger] = (
    ... // Various mappings
  )(Passenger.apply _)

  implicit val predictionWrites: Writes[Prediction] = (
    ... // Various mappings
  )(unlift(Prediction.unapply))

  def myEndpoint() = Action { implicit request: Request[AnyContent] =>
    val inputData: JsValue = request.body.asJson.get
    val passengers: Seq[Passenger] = inputData.validate[Seq[Passenger]].get
    val predictions: Seq[Prediction] = ModelScorer.predict(passengers)
    val outputData: JsValue = Json.toJson(predictions)
    Ok(outputData)
  }
}
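For reference, the Passenger and Prediction case classes are shaped roughly like this; the real fields live in models/Passenger.scala and models/Prediction.scala and aren't shown here, so the field names below are illustrative assumptions based on the columns ModelScorer selects:

```scala
// Illustrative shapes only -- the actual fields are elided in my project.
case class Passenger(name: String, age: Double, fare: Double)

// "survives" mirrors the renamed "prediction" column; "probability" is a
// Spark ML vector in practice, simplified to a String here.
case class Prediction(name: String, probability: String, survives: Double)
```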
To score predictions, the ModelScorer object initializes a SparkSession, loads the model through a Guava Cache, and exposes a predict method that runs the transform logic and returns the predictions to the HomeController. As far as I can tell, the offending line is

val ds: Dataset[Passenger] = passengers.toDS

which tells me something about Spark's initialization is off, but I'm not sure how to proceed.
// ModelScorer.scala
package controllers

import com.google.common.cache.{CacheBuilder, CacheLoader}
import models.{Passenger, Prediction}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.tuning.CrossValidatorModel
import org.apache.spark.sql.{Dataset, SparkSession}

object ModelScorer {

  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  val spark = SparkSession.builder
    .master("local[*]")
    .appName("ml-server")
    .getOrCreate()
  import spark.implicits._

  val modelCache = CacheBuilder.newBuilder()
    .build(
      new CacheLoader[String, CrossValidatorModel] {
        def load(path: String): CrossValidatorModel = {
          CrossValidatorModel.load(path)
        }
      }
    )

  val model: CrossValidatorModel = modelCache.get("trained-cv-pipeline")

  def predict(passengers: Seq[Passenger]): Seq[Prediction] = {
    val ds: Dataset[Passenger] = passengers.toDS
    val predictions: Seq[Prediction] = model.transform(ds)
      .select("name", "probability", "prediction")
      .withColumnRenamed("prediction", "survives")
      .as[Prediction]
      .collect
      .toSeq
    predictions
  }
}
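The Guava cache in ModelScorer only memoizes a single load per path; the same compute-once intent can be sketched in plain Scala (a hypothetical stand-in, not the Guava API):

```scala
import scala.collection.concurrent.TrieMap

// Minimal memoizing loader: a value is computed at most once per key,
// mirroring the compute-once semantics of the Guava CacheLoader above.
class MemoCache[V](load: String => V) {
  private val cache = TrieMap.empty[String, V]
  def get(path: String): V = cache.getOrElseUpdate(path, load(path))
}
```

A second get("trained-cv-pipeline") returns the already-loaded value rather than calling load again.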
The required dependencies in my build.sbt are:

libraryDependencies ++= Seq(
    guice
  , "org.scalatestplus.play" %% "scalatestplus-play" % "3.1.0" % Test
  , "org.apache.spark" %% "spark-core" % "2.2.0"
  , "org.apache.spark" %% "spark-sql" % "2.2.0"
  , "org.apache.spark" %% "spark-mllib" % "2.2.0"
  , "org.apache.hadoop" % "hadoop-client" % "2.7.2"
)

dependencyOverrides ++= Set(
    "com.fasterxml.jackson.core" % "jackson-databind" % "2.6.5"
  , "com.google.guava" % "guava" % "19.0"
)
After making the necessary JSON POST request to http://localhost:9000/myEndpoint, the stacktrace is:

@752mgi3ib - Internal server error, for (POST) [/myEndpoint] ->
play.api.http.HttpErrorHandlerExceptions$$anon$1: Execution
exception[[ScalaReflectionException: class models.Passenger in
JavaMirror with
DependencyClassLoader{file:/Users/XXXX/.ivy2/cache/org.scala-
lang/scala-library/jars/scala-library-2.11.11.jar,
...
... // Many, many lines
...
... :/Library/Java/JavaVirtualMachines/jdk1.8.0_131.jdk/Contents/Home/jre/classes] not found.
at play.api.http.HttpErrorHandlerExceptions$.throwableToUsefulException(HttpErrorHandler.scala:255)
at play.api.http.DefaultHttpErrorHandler.onServerError(HttpErrorHandler.scala:180)
at play.core.server.AkkaHttpServer$$anonfun$13$$anonfun$apply$1.applyOrElse(AkkaHttpServer.scala:252)
at play.core.server.AkkaHttpServer$$anonfun$13$$anonfun$apply$1.applyOrElse(AkkaHttpServer.scala:251)
at scala.concurrent.Future$$anonfun$recoverWith$1.apply(Future.scala:346)
at scala.concurrent.Future$$anonfun$recoverWith$1.apply(Future.scala:345)
at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:36)
at play.api.libs.streams.Execution$trampoline$.execute(Execution.scala:70)
at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:44)
at scala.concurrent.impl.Promise$DefaultPromise.scala$concurrent$impl$Promise$
DefaultPromise$$dispatchOrAddCallback(Promise.scala:284) Caused by: scala.ScalaReflectionException: class models.Passenger in JavaMirror with DependencyClassLoader{file:
...
... // Many, many lines
...
... :/Library/Java/JavaVirtualMachines/jdk1.8.0_131.jdk/Contents/Home/jre/classes] not found.
at scala.reflect.internal.Mirrors$RootsBase.staticClass(Mirrors.scala:123)
at scala.reflect.internal.Mirrors$RootsBase.staticClass(Mirrors.scala:22)
at controllers.ModelScorer$$typecreator3$1.apply(ModelScorer.scala:34)
at scala.reflect.api.TypeTags$WeakTypeTagImpl.tpe$lzycompute(TypeTags.scala:232)
at scala.reflect.api.TypeTags$WeakTypeTagImpl.tpe(TypeTags.scala:232)
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:49)
at org.apache.spark.sql.Encoders$.product(Encoders.scala:275)
at org.apache.spark.sql.LowPrioritySQLImplicits$class.newProductEncoder(SQLImplicits.scala:233)
at org.apache.spark.sql.SQLImplicits.newProductEncoder(SQLImplicits.scala:33)
at controllers.ModelScorer$.predict(ModelScorer.scala:34)
I can best trace the problem to the creation of the dataset in ModelScorer.predict(passengers), more specifically the line val ds: Dataset[Passenger] = passengers.toDS, even though I can run that line in a REPL via sbt console, which makes me think there's some issue with integrating Spark into Play. I'm a bit at a loss as to how to proceed; any and all guidance appreciated!