使用经过培训的Spark ML模型进行实时预测

时间:2017-09-17 11:17:04

标签: apache-spark pyspark apache-spark-ml

我们目前正在测试基于Spark在Python中实现LDA的预测引擎: https://spark.apache.org/docs/2.2.0/ml-clustering.html#latent-dirichlet-allocation-lda https://spark.apache.org/docs/2.2.0/api/python/pyspark.ml.html#pyspark.ml.clustering.LDA (我们使用的是pyspark.ml包,而不是pyspark.mllib)

我们能够成功地在Spark群集上训练模型(使用Google Cloud Dataproc)。现在我们尝试使用该模型作为API提供实时预测(例如烧瓶应用程序)。

实现目标的最佳方法是什么?

我们的主要痛点是,我们似乎需要恢复整个Spark环境,以便加载训练好的模型并运行转换。 到目前为止,我们已尝试为每个收到的请求以本地模式运行Spark,但这种方法给了我们:

  1. 表现不佳(启动SparkSession的时间,加载模型,运行转换......)
  2. 可扩展性差(无法处理并发请求)
  3. 整个方法看起来相当沉重,是否会有更简单的替代方案,甚至根本不需要暗示Spark?

    Bellow是训练和预测步骤的简化代码。

    培训代码

    def train(input_dataset):   
        conf = pyspark.SparkConf().setAppName("lda-train")
        spark = SparkSession.builder.config(conf=conf).getOrCreate()
    
        # Generate count vectors
        count_vectorizer = CountVectorizer(...)
        vectorizer_model = count_vectorizer.fit(input_dataset)
        vectorized_dataset = vectorizer_model.transform(input_dataset)
    
        # Instantiate LDA model
        lda = LDA(k=100, maxIter=100, optimizer="em", ...)
    
        # Train LDA model
        lda_model = lda.fit(vectorized_dataset)
    
        # Save models to external storage
        vectorizer_model.write().overwrite().save("gs://...")
        lda_model.write().overwrite().save("gs://...")
    

    预测代码

    def predict(input_query):
        conf = pyspark.SparkConf().setAppName("lda-predict").setMaster("local")
        spark = SparkSession.builder.config(conf=conf).getOrCreate()
    
        # Load models from external storage
        vectorizer_model = CountVectorizerModel.load("gs://...")
        lda_model = DistributedLDAModel.load("gs://...")
    
        # Run prediction on the input data using the loaded models
        vectorized_query = vectorizer_model.transform(input_query)
        transformed_query = lda_model.transform(vectorized_query)
    
        ...
    
        spark.stop()
    
        return transformed_query
    

1 个答案:

答案 0 :(得分:2)

如果您已经在spark中使用经过培训的机器学习模型,则可以使用Hydroshpere Mist来使用rest api 来提供模型(测试或预测),而无需创建Spark Context 。这将使您无需重新创建火花环境,仅依靠web services进行预测

参见: