有没有办法打电话给'预测' Spark ML MultilayerPerceptronClassificationModel的方法?

时间:2017-04-04 19:12:26

标签: neural-network apache-spark-mllib

对于较旧的ML模型,例如DecisionTreeModel,可以加载存储的模型并将其直接应用于单个数据点(特征向量),如下所示:

val features: Vector = <some vector of floats representing feature values> 
val modelDT = DecisionTreeModel.load(sparkContext, <"some-path">)
val prediction = modelDT.predict(features)

对于MultilayerPerceptronClassificationModel,预测方法受到保护,无法调用。这些功能需要包装在数据集中,结果将作为DataFrame返回一行。这很麻烦,并且为一次分类一个点的实时系统增加了大量开销。

1 个答案:

答案 0 :(得分:0)

因此,如果您像这样加载模型:

no table exists

或您之前训练过的人。

然后以这种形式将没有已知标签的输入数据放入

model = MultilayerPerceptronClassificationModel.load("Path.model")

您还可以执行Vectors.dense或其他稀疏声明

您需要导入一些内容,例如:

test = spark.sparkContext.parallelize([Row(features=Vectors.sparse(4, {1: 1.0, 2: 0.0, 3: -1.0, 4: 1.0}))]).toDF()

result = model.transform(test).head().prediction

print(result)