Spark多类逻辑回归类编号和标签

时间:2016-03-27 13:01:54

标签: scala apache-spark logistic-regression

我正在从here为scala运行Spark的逻辑回归示例。

在培训部分:

val model = new LogisticRegressionWithLBFGS().setNumClasses(10).run(training)

类的数量设置为10.如果我的数据由3个标签组成,分别为5,12和20;它引发了一个例外,如

ERROR DataValidators: Classification labels should be in {0 to 9}. Found 6 invalid labels.

我知道我可以通过将classnum设置为大于最大类值来解决它。

是否可以在此类数据集上运行具有真实数量类的算法,而无需对标签值进行显式转换?

如果我使用高classnum运行它以使其工作,算法是否会预测不存在的类,例如17以上?

1 个答案:

答案 0 :(得分:1)

我认为您可以做的最好的事情是map您的培训数据并修改每个条目,并使用Map兑换labels 0.0, 1.0, 2.0, ..., n - 1,其中n = number of classes import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.Vectors val rdd = sc.parallelize(List( LabeledPoint(5.0, Vectors.dense(1,2)), LabeledPoint(12.0, Vectors.dense(1,3)), LabeledPoint(20.0, Vectors.dense(-1,4)))) val map = Map(5 -> 0.0, 12.0 -> 1.0, 20.0 -> 2.0) val trainingData = rdd.map{ case LabeledPoint(category, features) => LabeledPoint(map(category), features) } val model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(trainingData)

{{1}}