我想对结果进行正确的解释。
源数据集(标签字段仅包含0和1)
scala> mlsrc.show()
+-----+---+---+---+
|label| f1| f2| f3|
+-----+---+---+---+
| 0.0|3.0|3.0| 1|
| 0.0|3.0|3.0| 1|
| 0.0|3.0|3.0| 2|
| 0.0|3.0|2.0| 1|
| 0.0|2.0|3.0| 2|
| 0.0|1.0|1.0| 3|
| 0.0|3.0|3.0| 1|
| 0.0|1.0|1.0| 2|
| 0.0|1.0|2.0| 1|
| 0.0|3.0|3.0| 1|
| 0.0|3.0|3.0| 1|
| 0.0|3.0|3.0| 1|
| 0.0|3.0|3.0| 1|
| 0.0|1.0|1.0| 2|
| 0.0|3.0|3.0| 2|
| 0.0|1.0|3.0| 2|
| 0.0|3.0|3.0| 1|
| 0.0|3.0|3.0| 1|
| 0.0|1.0|2.0| 3|
| 0.0|1.0|1.0| 3|
+-----+---+---+---+
将其转换为SparkML的libsvm格式。
scala> data.show(5)
+-----+---+---+---+-------------+
|label| f1| f2| f3| features|
+-----+---+---+---+-------------+
| 0.0|3.0|3.0| 1|[3.0,3.0,1.0]|
| 0.0|3.0|3.0| 1|[3.0,3.0,1.0]|
| 0.0|3.0|3.0| 2|[3.0,3.0,2.0]|
| 0.0|3.0|2.0| 1|[3.0,2.0,1.0]|
| 0.0|2.0|3.0| 2|[2.0,3.0,2.0]|
+-----+---+---+---+-------------+
并运行下一个代码。
val layers = Array[Int](3, 5, 5, 2)
val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setLabelCol("label").setFeaturesCol("features").setBlockSize(128).setSeed(1234L).setMaxIter(10)
val model = trainer.fit(train)
val result = model.transform(test)
result.show()
(3,5,5,2),因为我在功能中包含3个元素,并且知道只有2个可能的输出0,1。
尺寸为3的输入层,尺寸为5和5的两个中间层以及尺寸为2(类)的输出
结果如下:
+-----+---+---+---+-------------+--------------------+--------------------+----------+
|label| f1| f2| f3| features| rawPrediction| probability|prediction|
+-----+---+---+---+-------------+--------------------+--------------------+----------+
| 0.0|1.0|1.0| 3|[1.0,1.0,3.0]|[-1.7545448222707...|[0.46074576139667...| 1.0|
| 0.0|2.0|3.0| 2|[2.0,3.0,2.0]|[-1.7361574163221...|[0.46435300578321...| 1.0|
| 0.0|3.0|2.0| 1|[3.0,2.0,1.0]|[-1.6983478426376...|[0.47152530968704...| 1.0|
| 0.0|1.0|1.0| 2|[1.0,1.0,2.0]|[-1.7461462437059...|[0.46441191948172...| 1.0|
| 0.0|1.0|2.0| 1|[1.0,2.0,1.0]|[-1.7296803383526...|[0.47066804705632...| 1.0|
+-----+---+---+---+-------------+--------------------+--------------------+----------+
在列概率中,我具有大小为2的VectorUDT类型(元素的总和等于1),此处仅可见第一个元素,对于0.46,第二个值为0.54
我的问题: 我用以下方式解释结果: 对于特征(1.0,1.0,3)的预测值1(来自列预测)具有 概率= 0.54 我说得对吗?